diff --git a/.circleci/unittest/linux/scripts/run_style_checks.sh b/.circleci/unittest/linux/scripts/run_style_checks.sh index d8af8824f5..93a7adb5f5 100755 --- a/.circleci/unittest/linux/scripts/run_style_checks.sh +++ b/.circleci/unittest/linux/scripts/run_style_checks.sh @@ -38,7 +38,7 @@ fi printf "\x1b[34mRunning clang-format:\x1b[0m\n" "${this_dir}"/run_clang_format.py \ - -r torchaudio/csrc \ + -r torchaudio/csrc third_party/kaldi/src \ --clang-format-executable "${clangformat_path}" \ && git diff --exit-code status=$? diff --git a/.gitmodules b/.gitmodules index c01f8c91ad..dd47fc1e91 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,3 +2,7 @@ path = third_party/transducer/submodule url = https://github.com/HawkAaron/warp-transducer ignore = dirty +[submodule "kaldi"] + path = third_party/kaldi/submodule + url = https://github.com/kaldi-asr/kaldi + ignore = dirty diff --git a/CMakeLists.txt b/CMakeLists.txt index d72639a71e..87abef1262 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -47,6 +47,7 @@ endif() # Options option(BUILD_SOX "Build libsox statically" OFF) +option(BUILD_KALDI "Build kaldi statically" ON) option(BUILD_TRANSDUCER "Enable transducer" OFF) option(BUILD_LIBTORCHAUDIO "Build C++ Library" ON) option(BUILD_TORCHAUDIO_PYTHON_EXTENSION "Build Python extension" OFF) diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index d34464ed37..3d3cb6d095 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -68,6 +68,7 @@ def build_extension(self, ext): '-DCMAKE_VERBOSE_MAKEFILE=ON', f"-DPython_INCLUDE_DIR={distutils.sysconfig.get_python_inc()}", f"-DBUILD_SOX:BOOL={'ON' if _BUILD_SOX else 'OFF'}", + "-DBUILD_KALDI:BOOL=ON", f"-DBUILD_TRANSDUCER:BOOL={'ON' if _BUILD_TRANSDUCER else 'OFF'}", "-DBUILD_TORCHAUDIO_PYTHON_EXTENSION:BOOL=ON", "-DBUILD_LIBTORCHAUDIO:BOOL=OFF", diff --git a/test/torchaudio_unittest/assets/kaldi_test_pitch_args.json b/test/torchaudio_unittest/assets/kaldi_test_pitch_args.json new file mode 100644 index 0000000000..9844bd6c72 --- /dev/null +++ b/test/torchaudio_unittest/assets/kaldi_test_pitch_args.json @@ -0,0 +1,5 @@ +{"sample_rate": 8000} +{"sample_rate": 8000, "frames_per_chunk": 200} +{"sample_rate": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true} +{"sample_rate": 16000} +{"sample_rate": 44100} diff --git a/test/torchaudio_unittest/common_utils/kaldi_utils.py b/test/torchaudio_unittest/common_utils/kaldi_utils.py new file mode 100644 index 0000000000..ed8fc07e36 --- /dev/null +++ b/test/torchaudio_unittest/common_utils/kaldi_utils.py @@ -0,0 +1,39 @@ +import subprocess + +import torch + + +def convert_args(**kwargs): + args = [] + for key, value in kwargs.items(): + if key == 'sample_rate': + key = 'sample_frequency' + key = '--' + key.replace('_', '-') + value = str(value).lower() if value in [True, False] else str(value) + args.append('%s=%s' % (key, value)) + return args + + +def run_kaldi(command, input_type, input_value): + """Run provided Kaldi command, pass a tensor and get the resulting tensor + + Args: + input_type: str + 'ark' or 'scp' + input_value: + Tensor for 'ark' + string for 'scp' (path to an audio file) + """ + import kaldi_io + + key = 'foo' + process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) + if input_type == 'ark': + kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key) + elif input_type == 'scp': + process.stdin.write(f'{key} {input_value}'.encode('utf8')) + else: + raise NotImplementedError('Unexpected type') + process.stdin.close() + result = dict(kaldi_io.read_mat_ark(process.stdout))['foo'] + return torch.from_numpy(result.copy()) # copy supresses some torch warning diff --git a/test/torchaudio_unittest/functional/batch_consistency_test.py b/test/torchaudio_unittest/functional/batch_consistency_test.py index 7e48fcbbb7..1e6509d80f 100644 --- a/test/torchaudio_unittest/functional/batch_consistency_test.py +++ b/test/torchaudio_unittest/functional/batch_consistency_test.py @@ -184,3 +184,9 @@ def test_vad(self): waveform, sample_rate = torchaudio.load(filepath) self.assert_batch_consistencies( F.vad, waveform, sample_rate=sample_rate) + + @common_utils.skipIfNoExtension + def test_compute_kaldi_pitch(self): + sample_rate = 44100 + waveform = common_utils.get_whitenoise(sample_rate=sample_rate) + self.assert_batch_consistencies(F.compute_kaldi_pitch, waveform, sample_rate=sample_rate) diff --git a/test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py b/test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py new file mode 100644 index 0000000000..d8d6895a8f --- /dev/null +++ b/test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py @@ -0,0 +1,9 @@ +import torch + +from torchaudio_unittest.common_utils import PytorchTestCase +from .kaldi_compatibility_test_impl import KaldiCPUOnly + + +class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py b/test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py new file mode 100644 index 0000000000..0d802281f2 --- /dev/null +++ b/test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py @@ -0,0 +1,37 @@ +from parameterized import parameterized +import torchaudio.functional as F + +from torchaudio_unittest.common_utils import ( + get_sinusoid, + load_params, + save_wav, + skipIfNoExec, + TempDirMixin, + TestBaseMixin, +) +from torchaudio_unittest.common_utils.kaldi_utils import ( + convert_args, + run_kaldi, +) + + +class KaldiCPUOnly(TempDirMixin, TestBaseMixin): + def assert_equal(self, output, *, expected, rtol=None, atol=None): + expected = expected.to(dtype=self.dtype, device=self.device) + self.assertEqual(output, expected, rtol=rtol, atol=atol) + + @parameterized.expand(load_params('kaldi_test_pitch_args.json')) + @skipIfNoExec('compute-kaldi-pitch-feats') + def test_pitch_feats(self, kwargs): + """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats""" + sample_rate = kwargs['sample_rate'] + waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate) + result = F.compute_kaldi_pitch(waveform[0], **kwargs) + + waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate) + wave_file = self.get_temp_path('test.wav') + save_wav(wave_file, waveform, sample_rate) + + command = ['compute-kaldi-pitch-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) + self.assert_equal(result, expected=kaldi_result) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 83f7e8e95d..e00d4c8df5 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -547,3 +547,15 @@ def func(tensor): tensor = common_utils.get_whitenoise(sample_rate=44100) self._assert_consistency(func, tensor) + + @common_utils.skipIfNoExtension + def test_compute_kaldi_pitch(self): + if self.dtype != torch.float32 or self.device != torch.device('cpu'): + raise unittest.SkipTest("Only float32, cpu is supported.") + + def func(tensor): + sample_rate: float = 44100. + return F.compute_kaldi_pitch(tensor, sample_rate) + + tensor = common_utils.get_whitenoise(sample_rate=44100) + self._assert_consistency(func, tensor) diff --git a/test/torchaudio_unittest/kaldi_compatibility_impl.py b/test/torchaudio_unittest/kaldi_compatibility_impl.py index 983805fff5..e160936a2d 100644 --- a/test/torchaudio_unittest/kaldi_compatibility_impl.py +++ b/test/torchaudio_unittest/kaldi_compatibility_impl.py @@ -1,7 +1,4 @@ """Test suites for checking numerical compatibility against Kaldi""" -import subprocess - -import kaldi_io import torch import torchaudio.functional as F import torchaudio.compliance.kaldi @@ -9,46 +6,19 @@ from torchaudio_unittest.common_utils import ( TestBaseMixin, + TempDirMixin, load_params, skipIfNoExec, get_asset_path, - load_wav + load_wav, +) +from torchaudio_unittest.common_utils.kaldi_utils import ( + convert_args, + run_kaldi, ) -def _convert_args(**kwargs): - args = [] - for key, value in kwargs.items(): - key = '--' + key.replace('_', '-') - value = str(value).lower() if value in [True, False] else str(value) - args.append('%s=%s' % (key, value)) - return args - - -def _run_kaldi(command, input_type, input_value): - """Run provided Kaldi command, pass a tensor and get the resulting tensor - - Args: - input_type: str - 'ark' or 'scp' - input_value: - Tensor for 'ark' - string for 'scp' (path to an audio file) - """ - key = 'foo' - process = subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) - if input_type == 'ark': - kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key) - elif input_type == 'scp': - process.stdin.write(f'{key} {input_value}'.encode('utf8')) - else: - raise NotImplementedError('Unexpected type') - process.stdin.close() - result = dict(kaldi_io.read_mat_ark(process.stdout))['foo'] - return torch.from_numpy(result.copy()) # copy supresses some torch warning - - -class Kaldi(TestBaseMixin): +class Kaldi(TempDirMixin, TestBaseMixin): def assert_equal(self, output, *, expected, rtol=None, atol=None): expected = expected.to(dtype=self.dtype, device=self.device) self.assertEqual(output, expected, rtol=rtol, atol=atol) @@ -65,8 +35,8 @@ def test_sliding_window_cmn(self): tensor = torch.randn(40, 10, dtype=self.dtype, device=self.device) result = F.sliding_window_cmn(tensor, **kwargs) - command = ['apply-cmvn-sliding'] + _convert_args(**kwargs) + ['ark:-', 'ark:-'] - kaldi_result = _run_kaldi(command, 'ark', tensor) + command = ['apply-cmvn-sliding'] + convert_args(**kwargs) + ['ark:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'ark', tensor) self.assert_equal(result, expected=kaldi_result) @parameterized.expand(load_params('kaldi_test_fbank_args.json')) @@ -76,8 +46,8 @@ def test_fbank(self, kwargs): wave_file = get_asset_path('kaldi_file.wav') waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) result = torchaudio.compliance.kaldi.fbank(waveform, **kwargs) - command = ['compute-fbank-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] - kaldi_result = _run_kaldi(command, 'scp', wave_file) + command = ['compute-fbank-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) @parameterized.expand(load_params('kaldi_test_spectrogram_args.json')) @@ -87,8 +57,8 @@ def test_spectrogram(self, kwargs): wave_file = get_asset_path('kaldi_file.wav') waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) result = torchaudio.compliance.kaldi.spectrogram(waveform, **kwargs) - command = ['compute-spectrogram-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] - kaldi_result = _run_kaldi(command, 'scp', wave_file) + command = ['compute-spectrogram-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) @parameterized.expand(load_params('kaldi_test_mfcc_args.json')) @@ -98,6 +68,6 @@ def test_mfcc(self, kwargs): wave_file = get_asset_path('kaldi_file.wav') waveform = load_wav(wave_file, normalize=False)[0].to(dtype=self.dtype, device=self.device) result = torchaudio.compliance.kaldi.mfcc(waveform, **kwargs) - command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] - kaldi_result = _run_kaldi(command, 'scp', wave_file) + command = ['compute-mfcc-feats'] + convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt index 3e1ab962fe..38acd4b52f 100644 --- a/third_party/CMakeLists.txt +++ b/third_party/CMakeLists.txt @@ -17,6 +17,14 @@ else() endif() list(APPEND TORCHAUDIO_THIRD_PARTIES libsox) +################################################################################ +# kaldi +################################################################################ +if (BUILD_KALDI) + add_subdirectory(kaldi) + list(APPEND TORCHAUDIO_THIRD_PARTIES kaldi) +endif() + ################################################################################ # transducer ################################################################################ diff --git a/third_party/kaldi/CMakeLists.txt b/third_party/kaldi/CMakeLists.txt new file mode 100644 index 0000000000..6f891f8d1a --- /dev/null +++ b/third_party/kaldi/CMakeLists.txt @@ -0,0 +1,30 @@ +set(KALDI_REPO ${CMAKE_CURRENT_SOURCE_DIR}/submodule) + +# Apply custom patch +execute_process( + WORKING_DIRECTORY ${KALDI_REPO} + COMMAND "git" "checkout" "." + ) +execute_process( + WORKING_DIRECTORY ${KALDI_REPO} + COMMAND git apply ../kaldi.patch + ) +# Update the version string +execute_process( + WORKING_DIRECTORY ${KALDI_REPO}/src/base + COMMAND sh get_version.sh + ) + +set(KALDI_SOURCES + src/matrix/kaldi-vector.cc + src/matrix/kaldi-matrix.cc + submodule/src/base/kaldi-error.cc + submodule/src/base/kaldi-math.cc + submodule/src/feat/feature-functions.cc + submodule/src/feat/pitch-functions.cc + submodule/src/feat/resample.cc + ) + +add_library(kaldi STATIC ${KALDI_SOURCES}) +target_include_directories(kaldi PUBLIC src submodule/src) +target_link_libraries(kaldi ${TORCH_LIBRARIES}) diff --git a/third_party/kaldi/README.md b/third_party/kaldi/README.md new file mode 100644 index 0000000000..58c48747a8 --- /dev/null +++ b/third_party/kaldi/README.md @@ -0,0 +1,6 @@ +# Custom Kaldi build + +This directory contains original Kaldi repository (as submodule), [the custom implementation of Kaldi's vector/matrix](./src) and the build script. + +We use the custom build process so that the resulting library only contains what torchaudio needs. +We use the custom vector/matrix implementation so that we can use the same BLAS library that PyTorch is compiled with, and so that we can (hopefully, in future) take advantage of other PyTorch features (such as differentiability and GPU support). The down side of this approach is that it adds a lot of overhead compared to the original Kaldi (operator dispatch and element-wise processing, which PyTorch is not efficient at). We can improve this gradually, and if you are interested in helping, please let us know by opening an issue. \ No newline at end of file diff --git a/third_party/kaldi/kaldi.patch b/third_party/kaldi/kaldi.patch new file mode 100644 index 0000000000..40667bced8 --- /dev/null +++ b/third_party/kaldi/kaldi.patch @@ -0,0 +1,76 @@ +diff --git a/src/base/kaldi-types.h b/src/base/kaldi-types.h +index 7ebf4f853..c15b288b2 100644 +--- a/src/base/kaldi-types.h ++++ b/src/base/kaldi-types.h +@@ -41,6 +41,7 @@ typedef float BaseFloat; + + // for discussion on what to do if you need compile kaldi + // without OpenFST, see the bottom of this this file ++/* + #include + + namespace kaldi { +@@ -53,10 +54,10 @@ namespace kaldi { + typedef float float32; + typedef double double64; + } // end namespace kaldi ++*/ + + // In a theoretical case you decide compile Kaldi without the OpenFST + // comment the previous namespace statement and uncomment the following +-/* + namespace kaldi { + typedef int8_t int8; + typedef int16_t int16; +@@ -70,6 +71,5 @@ namespace kaldi { + typedef float float32; + typedef double double64; + } // end namespace kaldi +-*/ + + #endif // KALDI_BASE_KALDI_TYPES_H_ +diff --git a/src/matrix/matrix-lib.h b/src/matrix/matrix-lib.h +index b6059b06c..4fb9e1b16 100644 +--- a/src/matrix/matrix-lib.h ++++ b/src/matrix/matrix-lib.h +@@ -25,14 +25,14 @@ + #include "base/kaldi-common.h" + #include "matrix/kaldi-vector.h" + #include "matrix/kaldi-matrix.h" +-#include "matrix/sp-matrix.h" +-#include "matrix/tp-matrix.h" ++// #include "matrix/sp-matrix.h" ++// #include "matrix/tp-matrix.h" + #include "matrix/matrix-functions.h" + #include "matrix/srfft.h" + #include "matrix/compressed-matrix.h" +-#include "matrix/sparse-matrix.h" ++// #include "matrix/sparse-matrix.h" + #include "matrix/optimization.h" +-#include "matrix/numpy-array.h" ++// #include "matrix/numpy-array.h" + + #endif + +diff --git a/src/util/common-utils.h b/src/util/common-utils.h +index cfb0c255c..48d199e97 100644 +--- a/src/util/common-utils.h ++++ b/src/util/common-utils.h +@@ -21,11 +21,11 @@ + + #include "base/kaldi-common.h" + #include "util/parse-options.h" +-#include "util/kaldi-io.h" +-#include "util/simple-io-funcs.h" +-#include "util/kaldi-holder.h" +-#include "util/kaldi-table.h" +-#include "util/table-types.h" +-#include "util/text-utils.h" ++// #include "util/kaldi-io.h" ++// #include "util/simple-io-funcs.h" ++// #include "util/kaldi-holder.h" ++// #include "util/kaldi-table.h" ++// #include "util/table-types.h" ++// #include "util/text-utils.h" + + #endif // KALDI_UTIL_COMMON_UTILS_H_ diff --git a/third_party/kaldi/src/matrix/kaldi-matrix.cc b/third_party/kaldi/src/matrix/kaldi-matrix.cc new file mode 100644 index 0000000000..a89c3809c9 --- /dev/null +++ b/third_party/kaldi/src/matrix/kaldi-matrix.cc @@ -0,0 +1,39 @@ +#include "matrix/kaldi-matrix.h" +#include + +namespace { + +template +void assert_matrix_shape(const torch::Tensor& tensor_); + +template <> +void assert_matrix_shape(const torch::Tensor& tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +template <> +void assert_matrix_shape(const torch::Tensor& tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +} // namespace + +namespace kaldi { + +template +MatrixBase::MatrixBase(torch::Tensor tensor) : tensor_(tensor) { + assert_matrix_shape(tensor_); +}; + +template class Matrix; +template class Matrix; +template class MatrixBase; +template class MatrixBase; +template class SubMatrix; +template class SubMatrix; + +} // namespace kaldi diff --git a/third_party/kaldi/src/matrix/kaldi-matrix.h b/third_party/kaldi/src/matrix/kaldi-matrix.h new file mode 100644 index 0000000000..f64828b84f --- /dev/null +++ b/third_party/kaldi/src/matrix/kaldi-matrix.h @@ -0,0 +1,178 @@ +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ + +#include +#include "matrix/kaldi-vector.h" +#include "matrix/matrix-common.h" + +using namespace torch::indexing; + +namespace kaldi { + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48 +template +class MatrixBase { + public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-specific items + //////////////////////////////////////////////////////////////////////////////// + torch::Tensor tensor_; + /// Construct VectorBase which is an interface to an existing torch::Tensor + /// object. + MatrixBase(torch::Tensor tensor); + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible items + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L62-L63 + inline MatrixIndexT NumRows() const { + return tensor_.size(0); + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L65-L66 + inline MatrixIndexT NumCols() const { + return tensor_.size(1); + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L177-L178 + void CopyColFromVec(const VectorBase& v, const MatrixIndexT col) { + tensor_.index_put_({Slice(), col}, v.tensor_); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L99-L107 + inline Real& operator()(MatrixIndexT r, MatrixIndexT c) { + // CPU only + return tensor_.accessor()[r][c]; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L112-L120 + inline const Real operator()(MatrixIndexT r, MatrixIndexT c) const { + return tensor_.index({Slice(r), Slice(c)}).item().template to(); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L138-L141 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L859-L898 + template + void CopyFromMat( + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans) { + auto src = M.tensor_; + if (trans == kTrans) + src = src.transpose(1, 0); + tensor_.index_put_({Slice(), Slice()}, src); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L186-L191 + inline const SubVector Row(MatrixIndexT i) const { + return SubVector(*this, i); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L208-L211 + inline SubMatrix RowRange( + const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, NumCols()); + } + + protected: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L749-L753 + explicit MatrixBase() : tensor_(torch::empty({0, 0})) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L781-L784 +template +class Matrix : public MatrixBase { + public: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L786-L787 + Matrix() : MatrixBase() {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L789-L793 + Matrix( + const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) + : MatrixBase() { + Resize(r, c, resize_type, stride_type); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L808-L811 + explicit Matrix( + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans) + : MatrixBase( + trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L816-L819 + template + explicit Matrix( + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans) + : MatrixBase( + trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L859-L874 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L817-L857 + void Resize( + const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) { + auto& tensor_ = MatrixBase::tensor_; + switch (resize_type) { + case kSetZero: + tensor_.resize_({r, c}).zero_(); + break; + case kUndefined: + tensor_.resize_({r, c}); + break; + case kCopyData: + auto tmp = tensor_; + auto tmp_rows = tmp.size(0); + auto tmp_cols = tmp.size(1); + tensor_.resize_({r, c}).zero_(); + auto rows = Slice(None, r < tmp_rows ? r : tmp_rows); + auto cols = Slice(None, c < tmp_cols ? c : tmp_cols); + tensor_.index_put_({rows, cols}, tmp.index({rows, cols})); + break; + } + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L876-L883 + Matrix& operator=(const MatrixBase& other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L940-L948 +template +class SubMatrix : public MatrixBase { + public: + SubMatrix( + const MatrixBase& T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c) // number of columns, c > 0 + : MatrixBase( + T.tensor_.index({Slice(ro, ro + r), Slice(co, co + c)})) {} +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L1059-L1060 +template +std::ostream& operator<<(std::ostream& Out, const MatrixBase& M) { + Out << M.tensor_; + return Out; +} + +} // namespace kaldi + +#endif diff --git a/third_party/kaldi/src/matrix/kaldi-vector.cc b/third_party/kaldi/src/matrix/kaldi-vector.cc new file mode 100644 index 0000000000..df59f13a36 --- /dev/null +++ b/third_party/kaldi/src/matrix/kaldi-vector.cc @@ -0,0 +1,42 @@ +#include "matrix/kaldi-vector.h" +#include +#include "matrix/kaldi-matrix.h" + +namespace { + +template +void assert_vector_shape(const torch::Tensor& tensor_); + +template <> +void assert_vector_shape(const torch::Tensor& tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +template <> +void assert_vector_shape(const torch::Tensor& tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +} // namespace + +namespace kaldi { + +template +VectorBase::VectorBase(torch::Tensor tensor) + : tensor_(tensor), data_(tensor.data_ptr()) { + assert_vector_shape(tensor_); +}; + +template +VectorBase::VectorBase() : VectorBase(torch::empty({0})) {} + +template class Vector; +template class Vector; +template class VectorBase; +template class VectorBase; + +} // namespace kaldi diff --git a/third_party/kaldi/src/matrix/kaldi-vector.h b/third_party/kaldi/src/matrix/kaldi-vector.h new file mode 100644 index 0000000000..620f3676d3 --- /dev/null +++ b/third_party/kaldi/src/matrix/kaldi-vector.h @@ -0,0 +1,313 @@ +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ + +#include +#include "matrix/matrix-common.h" + +using namespace torch::indexing; + +namespace kaldi { + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L36-L40 +template +class VectorBase { + public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-specific things + //////////////////////////////////////////////////////////////////////////////// + torch::Tensor tensor_; + + /// Construct VectorBase which is an interface to an existing torch::Tensor + /// object. + VectorBase(torch::Tensor tensor); + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible methods + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L42-L43 + void SetZero() { + Set(0); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L48-L49 + void Set(Real f) { + tensor_.index_put_({"..."}, f); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L62-L63 + inline MatrixIndexT Dim() const { + return tensor_.numel(); + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L68-L69 + inline Real* Data() { + return data_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L71-L72 + inline const Real* Data() const { + return data_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L74-L79 + inline Real operator()(MatrixIndexT i) const { + return data_[i]; + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L81-L86 + inline Real& operator()(MatrixIndexT i) { + return tensor_.accessor()[i]; + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L88-L95 + SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector(*this, o, l); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L97-L105 + const SubVector Range(const MatrixIndexT o, const MatrixIndexT l) + const { + return SubVector(*this, o, l); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L107-L108 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L226-L233 + void CopyFromVec(const VectorBase& v) { + TORCH_INTERNAL_ASSERT(tensor_.sizes() == v.tensor_.sizes()); + tensor_.copy_(v.tensor_); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L137-L139 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L816-L832 + void ApplyFloor(Real floor_val, MatrixIndexT* floored_count = nullptr) { + auto index = tensor_ < floor_val; + auto tmp = tensor_.index_put_({index}, floor_val); + if (floored_count) { + *floored_count = index.sum().item().template to(); + } + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L164-L165 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L449-L479 + void ApplyPow(Real power) { + tensor_.pow_(power); + TORCH_INTERNAL_ASSERT(!tensor_.isnan().sum().item().template to()); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L181-L184 + template + void AddVec(const Real alpha, const VectorBase& v) { + tensor_ += alpha * v.tensor_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L186-L187 + void AddVec2(const Real alpha, const VectorBase& v) { + tensor_ += alpha * (v.tensor_.square()); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L196-L198 + void AddMatVec( + const Real alpha, + const MatrixBase& M, + const MatrixTransposeType trans, + const VectorBase& v, + const Real beta) { // **beta previously defaulted to 0.0** + auto mat = M.tensor_; + if (trans == kTrans) { + mat = mat.transpose(1, 0); + } + tensor_.addmv_(mat, v.tensor_, beta, alpha); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L221-L222 + void MulElements(const VectorBase& v) { + tensor_ *= v.tensor_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L233-L234 + void Add(Real c) { + tensor_ += c; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L236-L239 + void AddVecVec( + Real alpha, + const VectorBase& v, + const VectorBase& r, + Real beta) { + tensor_ = beta * tensor_ + alpha * v.tensor_ * r.tensor_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L246-L247 + void Scale(Real alpha) { + tensor_ *= alpha; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L305-L306 + Real Min() const { + if (tensor_.numel()) { + return tensor_.min().item().template to(); + } + return std::numeric_limits::infinity(); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L308-L310 + Real Min(MatrixIndexT* index) const { + TORCH_INTERNAL_ASSERT(tensor_.numel()); + torch::Tensor value, ind; + std::tie(value, ind) = tensor_.min(0); + *index = ind.item().to(); + return value.item().to(); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L312-L313 + Real Sum() const { + return tensor_.sum().item().template to(); + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L320-L321 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L718-L736 + void AddRowSumMat(Real alpha, const MatrixBase& M, Real beta = 1.0) { + Vector ones(M.NumRows()); + ones.Set(1.0); + this->AddMatVec(alpha, M, kTrans, ones, beta); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L323-L324 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L738-L757 + void AddColSumMat(Real alpha, const MatrixBase& M, Real beta = 1.0) { + Vector ones(M.NumCols()); + ones.Set(1.0); + this->AddMatVec(alpha, M, kNoTrans, ones, beta); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L326-L330 + void AddDiagMat2( + Real alpha, + const MatrixBase& M, + MatrixTransposeType trans = kNoTrans, + Real beta = 1.0) { + auto mat = M.tensor_; + if (trans == kNoTrans) { + tensor_ = + beta * tensor_ + torch::diag(torch::mm(mat, mat.transpose(1, 0))); + } else { + tensor_ = + beta * tensor_ + torch::diag(torch::mm(mat.transpose(1, 0), mat)); + } + } + + protected: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L362-L365 + explicit VectorBase(); + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L378-L379 + Real* data_; + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L382 + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L385-L390 +template +class Vector : public VectorBase { + public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-compatibility things + //////////////////////////////////////////////////////////////////////////////// + /// Construct VectorBase which is an interface to an existing torch::Tensor + /// object. + Vector(torch::Tensor tensor) : VectorBase(tensor){}; + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible methods + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L392-L393 + Vector() : VectorBase(){}; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L395-L399 + explicit Vector(const MatrixIndexT s, MatrixResizeType resize_type = kSetZero) + : VectorBase() { + Resize(s, resize_type); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L406-L410 + // Note: unlike the original implementation, this is "explicit". + explicit Vector(const Vector& v) + : VectorBase(v.tensor_.clone()) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L412-L416 + explicit Vector(const VectorBase& v) + : VectorBase(v.tensor_.clone()) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L434-L435 + void Swap(Vector* other) { + auto tmp = VectorBase::tensor_; + this->tensor_ = other->tensor_; + other->tensor_ = tmp; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L444-L451 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L189-L223 + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero) { + auto& tensor_ = this->tensor_; + switch (resize_type) { + case kSetZero: + tensor_.resize_({length}).zero_(); + break; + case kUndefined: + tensor_.resize_({length}); + break; + case kCopyData: + auto tmp = tensor_; + auto tmp_numel = tensor_.numel(); + tensor_.resize_({length}).zero_(); + auto numel = Slice(length < tmp_numel ? length : tmp_numel); + tensor_.index_put_({numel}, tmp.index({numel})); + break; + } + // data_ptr() causes compiler error + this->data_ = static_cast(tensor_.data_ptr()); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L463-L468 + Vector& operator=(const VectorBase& other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L482-L485 +template +class SubVector : public VectorBase { + public: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L487-L499 + SubVector( + const VectorBase& t, + const MatrixIndexT origin, + const MatrixIndexT length) + : VectorBase(t.tensor_.index({Slice(origin, origin + length)})) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L524-L528 + SubVector(const MatrixBase& matrix, MatrixIndexT row) + : VectorBase(matrix.tensor_.index({row})) {} +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L540-L543 +template +std::ostream& operator<<(std::ostream& out, const VectorBase& v) { + out << v.tensor_; + return out; +} + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L573-L575 +template +Real VecVec(const VectorBase& v1, const VectorBase& v2) { + return torch::dot(v1.tensor_, v2.tensor_).item().template to(); +} + +} // namespace kaldi + +#endif diff --git a/third_party/kaldi/submodule b/third_party/kaldi/submodule new file mode 160000 index 0000000000..3eea37dd09 --- /dev/null +++ b/third_party/kaldi/submodule @@ -0,0 +1 @@ +Subproject commit 3eea37dd09b55064e6362216f7e9a60641f29f09 diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt index f9c3a4d08f..409d42f57f 100644 --- a/torchaudio/csrc/CMakeLists.txt +++ b/torchaudio/csrc/CMakeLists.txt @@ -15,6 +15,10 @@ if(BUILD_TRANSDUCER) list(APPEND LIBTORCHAUDIO_SOURCES transducer.cpp) endif() +if(BUILD_KALDI) + list(APPEND LIBTORCHAUDIO_SOURCES kaldi.cpp) +endif() + ################################################################################ # libtorchaudio.so ################################################################################ diff --git a/torchaudio/csrc/kaldi.cpp b/torchaudio/csrc/kaldi.cpp new file mode 100644 index 0000000000..6f2b36c28f --- /dev/null +++ b/torchaudio/csrc/kaldi.cpp @@ -0,0 +1,93 @@ +#include +#include "feat/pitch-functions.h" + +namespace torchaudio { +namespace kaldi { + +namespace { + +torch::Tensor denormalize(const torch::Tensor& t) { + auto ret = t; + auto pos = t > 0, neg = t < 0; + ret.index_put({pos}, t.index({pos}) * 32767); + ret.index_put({neg}, t.index({neg}) * 32768); + return ret; +} + +torch::Tensor compute_kaldi_pitch( + const torch::Tensor& wave, + const ::kaldi::PitchExtractionOptions& opts) { + ::kaldi::VectorBase<::kaldi::BaseFloat> input(wave); + ::kaldi::Matrix<::kaldi::BaseFloat> output; + ::kaldi::ComputeKaldiPitch(opts, input, &output); + return output.tensor_; +} + +} // namespace + +torch::Tensor ComputeKaldiPitch( + const torch::Tensor& wave, + double sample_frequency, + double frame_length, + double frame_shift, + double min_f0, + double max_f0, + double soft_min_f0, + double penalty_factor, + double lowpass_cutoff, + double resample_frequency, + double delta_pitch, + double nccf_ballast, + int64_t lowpass_filter_width, + int64_t upsample_filter_width, + int64_t max_frames_latency, + int64_t frames_per_chunk, + bool simulate_first_pass_online, + int64_t recompute_frame, + bool snip_edges) { + TORCH_CHECK(wave.ndimension() == 2, "Input tensor must be 2 dimentional."); + TORCH_CHECK(wave.device().is_cpu(), "Input tensor must be on CPU."); + TORCH_CHECK( + wave.dtype() == torch::kFloat32, "Input tensor must be float32 type."); + + ::kaldi::PitchExtractionOptions opts; + opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency); + opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift); + opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length); + opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0); + opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0); + opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0); + opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor); + opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff); + opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency); + opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch); + opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width); + opts.upsample_filter_width = + static_cast<::kaldi::int32>(upsample_filter_width); + opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency); + opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk); + opts.simulate_first_pass_online = simulate_first_pass_online; + opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame); + opts.snip_edges = snip_edges; + + // Kaldi's float type expects value range of int16 expressed as float + torch::Tensor wave_ = denormalize(wave); + + auto batch_size = wave_.size(0); + std::vector results(batch_size); + at::parallel_for(0, batch_size, 1, [&](int64_t begin, int64_t end) { + for (auto i = begin; i < end; ++i) { + results[i] = compute_kaldi_pitch(wave_.index({i}), opts); + } + }); + return torch::stack(results, 0); +} + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + m.def( + "torchaudio::kaldi_ComputeKaldiPitch", + &torchaudio::kaldi::ComputeKaldiPitch); +} + +} // namespace kaldi +} // namespace torchaudio diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py index d59400e82e..0d29147be3 100644 --- a/torchaudio/functional/__init__.py +++ b/torchaudio/functional/__init__.py @@ -3,6 +3,7 @@ angle, complex_norm, compute_deltas, + compute_kaldi_pitch, create_dct, create_fb_matrix, DB_to_amplitude, @@ -47,6 +48,7 @@ 'angle', 'complex_norm', 'compute_deltas', + 'compute_kaldi_pitch', 'create_dct', 'create_fb_matrix', 'DB_to_amplitude', diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 4dde324a31..cbc5293731 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -13,6 +13,7 @@ "amplitude_to_DB", "DB_to_amplitude", "compute_deltas", + "compute_kaldi_pitch", "create_fb_matrix", "create_dct", "compute_deltas", @@ -991,3 +992,105 @@ def spectral_centroid( device=specgram.device).reshape((-1, 1)) freq_dim = -2 return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) + + +def compute_kaldi_pitch( + waveform: torch.Tensor, + sample_rate: float, + frame_length: float = 25.0, + frame_shift: float = 10.0, + min_f0: float = 50, + max_f0: float = 400, + soft_min_f0: float = 10.0, + penalty_factor: float = 0.1, + lowpass_cutoff: float = 1000, + resample_frequency: float = 4000, + delta_pitch: float = 0.005, + nccf_ballast: float = 7000, + lowpass_filter_width: int = 1, + upsample_filter_width: int = 5, + max_frames_latency: int = 0, + frames_per_chunk: int = 0, + simulate_first_pass_online: bool = False, + recompute_frame: int = 500, + snip_edges: bool = True, +) -> torch.Tensor: + """Extract pitch based on method described in [1]. + + This function computes the equivalent of `compute-kaldi-pitch-feats` from Kaldi. + + Args: + waveform (Tensor): + The input waveform of shape `(..., time)`. + sample_rate (float): + Sample rate of `waveform`. + frame_length (float, optional): + Frame length in milliseconds. + frame_shift (float, optional): + Frame shift in milliseconds. + min_f0 (float, optional): + Minimum F0 to search for (Hz) + max_f0 (float, optional): + Maximum F0 to search for (Hz) + soft_min_f0 (float, optional): + Minimum f0, applied in soft way, must not exceed min-f0 + penalty_factor (float, optional): + Cost factor for FO change. + lowpass_cutoff (float, optional): + Cutoff frequency for LowPass filter (Hz) + resample_frequency (float, optional): + Frequency that we down-sample the signal to. Must be more than twice lowpass-cutoff. + delta_pitch( float, optional): + Smallest relative change in pitch that our algorithm measures. + nccf_ballast (float, optional): + Increasing this factor reduces NCCF for quiet frames + lowpass_filter_width (int, optional): + Integer that determines filter width of lowpass filter, more gives sharper filter. + upsample_filter_width (int, optional): + Integer that determines filter width when upsampling NCCF. + max_frames_latency (int, optional): + Maximum number of frames of latency that we allow pitch tracking to introduce into + the feature processing (affects output only if ``frames_per_chunk > 0`` and + ``simulate_first_pass_online=True``) + frames_per_chunk (int, optional): + The number of frames used for energy normalization. + simulate_first_pass_online (bool, optional): + If true, the function will output features that correspond to what an online decoder + would see in the first pass of decoding -- not the final version of the features, + which is the default. + Relevant if ``frames_per_chunk > 0``. + recompute_frame (int, optional): + Only relevant for compatibility with online pitch extraction. + A non-critical parameter; the frame at which we recompute some of the forward pointers, + after revising our estimate of the signal energy. + Relevant if ``frames_per_chunk > 0``. + snip_edges (bool, optional): + If this is set to false, the incomplete frames near the ending edge won't be snipped, + so that the number of frames is the file size divided by the frame-shift. + This makes different types of features give the same number of frames. + + Returns: + Tensor: Pitch feature. Shape: `(batch, frames 2)` where the last dimension + corresponds to pitch and NCCF. + + Reference: + - A pitch extraction algorithm tuned for automatic speech recognition + + P. Ghahremani, B. BabaAli, D. Povey, K. Riedhammer, J. Trmal and S. Khudanpur + + 2014 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP), + + Florence, 2014, pp. 2494-2498, doi: 10.1109/ICASSP.2014.6854049. + """ + shape = waveform.shape + waveform = waveform.reshape(-1, shape[-1]) + result = torch.ops.torchaudio.kaldi_ComputeKaldiPitch( + waveform, sample_rate, frame_length, frame_shift, + min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff, + resample_frequency, delta_pitch, nccf_ballast, + lowpass_filter_width, upsample_filter_width, max_frames_latency, + frames_per_chunk, simulate_first_pass_online, recompute_frame, + snip_edges, + ) + result = result.reshape(shape[:-1] + result.shape[-2:]) + return result