From 1a72ebca98cd8ed9f9342e23a009d44cef332b58 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 2 Dec 2020 00:43:13 +0000 Subject: [PATCH] Add ComputeKaldiPitch --- build_tools/setup_helpers/extension.py | 3 + .../assets/kaldi_test_pitch_args.json | 5 + .../kaldi_compatibility_cpu_test.py | 7 +- .../kaldi_compatibility_impl.py | 27 +- torchaudio/csrc/kaldi/base/io-funcs-inl.h | 327 ++++ torchaudio/csrc/kaldi/base/io-funcs.h | 245 +++ torchaudio/csrc/kaldi/base/kaldi-common.h | 41 + torchaudio/csrc/kaldi/base/kaldi-error.cc | 249 +++ torchaudio/csrc/kaldi/base/kaldi-error.h | 231 +++ torchaudio/csrc/kaldi/base/kaldi-math.cc | 162 ++ torchaudio/csrc/kaldi/base/kaldi-math.h | 363 ++++ torchaudio/csrc/kaldi/base/kaldi-types.h | 75 + torchaudio/csrc/kaldi/base/kaldi-utils.h | 155 ++ torchaudio/csrc/kaldi/base/timer.h | 115 ++ torchaudio/csrc/kaldi/base/version.h | 4 + .../csrc/kaldi/feat/feature-common-inl.h | 99 + torchaudio/csrc/kaldi/feat/feature-common.h | 176 ++ torchaudio/csrc/kaldi/feat/feature-fbank.h | 149 ++ .../csrc/kaldi/feat/feature-functions.cc | 362 ++++ .../csrc/kaldi/feat/feature-functions.h | 204 ++ torchaudio/csrc/kaldi/feat/feature-mfcc.h | 154 ++ torchaudio/csrc/kaldi/feat/feature-plp.h | 176 ++ torchaudio/csrc/kaldi/feat/feature-window.h | 223 +++ torchaudio/csrc/kaldi/feat/mel-computations.h | 171 ++ torchaudio/csrc/kaldi/feat/online-feature.h | 632 +++++++ torchaudio/csrc/kaldi/feat/pitch-functions.cc | 1667 +++++++++++++++++ torchaudio/csrc/kaldi/feat/pitch-functions.h | 450 +++++ torchaudio/csrc/kaldi/feat/resample.cc | 377 ++++ torchaudio/csrc/kaldi/feat/resample.h | 287 +++ .../csrc/kaldi/itf/online-feature-itf.h | 125 ++ torchaudio/csrc/kaldi/itf/options-itf.h | 49 + torchaudio/csrc/kaldi/kaldi.cc | 69 + torchaudio/csrc/kaldi/kaldi.h | 35 + .../csrc/kaldi/matrix/compressed-matrix.h | 283 +++ torchaudio/csrc/kaldi/matrix/kaldi-matrix.cc | 36 + torchaudio/csrc/kaldi/matrix/kaldi-matrix.h | 163 ++ 
torchaudio/csrc/kaldi/matrix/kaldi-vector.cc | 43 + torchaudio/csrc/kaldi/matrix/kaldi-vector.h | 281 +++ torchaudio/csrc/kaldi/matrix/matrix-common.h | 111 ++ .../csrc/kaldi/matrix/matrix-functions-inl.h | 56 + .../csrc/kaldi/matrix/matrix-functions.h | 174 ++ torchaudio/csrc/kaldi/matrix/matrix-lib.h | 38 + torchaudio/csrc/kaldi/matrix/optimization.h | 248 +++ torchaudio/csrc/kaldi/matrix/srfft.h | 141 ++ torchaudio/csrc/kaldi/register.cpp | 10 + torchaudio/csrc/kaldi/util/common-utils.h | 31 + torchaudio/csrc/kaldi/util/parse-options.h | 264 +++ torchaudio/functional/__init__.py | 1 + torchaudio/functional/functional.py | 35 + 49 files changed, 9326 insertions(+), 3 deletions(-) create mode 100644 test/torchaudio_unittest/assets/kaldi_test_pitch_args.json create mode 100644 torchaudio/csrc/kaldi/base/io-funcs-inl.h create mode 100644 torchaudio/csrc/kaldi/base/io-funcs.h create mode 100644 torchaudio/csrc/kaldi/base/kaldi-common.h create mode 100644 torchaudio/csrc/kaldi/base/kaldi-error.cc create mode 100644 torchaudio/csrc/kaldi/base/kaldi-error.h create mode 100644 torchaudio/csrc/kaldi/base/kaldi-math.cc create mode 100644 torchaudio/csrc/kaldi/base/kaldi-math.h create mode 100644 torchaudio/csrc/kaldi/base/kaldi-types.h create mode 100644 torchaudio/csrc/kaldi/base/kaldi-utils.h create mode 100644 torchaudio/csrc/kaldi/base/timer.h create mode 100644 torchaudio/csrc/kaldi/base/version.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-common-inl.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-common.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-fbank.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-functions.cc create mode 100644 torchaudio/csrc/kaldi/feat/feature-functions.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-mfcc.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-plp.h create mode 100644 torchaudio/csrc/kaldi/feat/feature-window.h create mode 100644 torchaudio/csrc/kaldi/feat/mel-computations.h create 
mode 100644 torchaudio/csrc/kaldi/feat/online-feature.h create mode 100644 torchaudio/csrc/kaldi/feat/pitch-functions.cc create mode 100644 torchaudio/csrc/kaldi/feat/pitch-functions.h create mode 100644 torchaudio/csrc/kaldi/feat/resample.cc create mode 100644 torchaudio/csrc/kaldi/feat/resample.h create mode 100644 torchaudio/csrc/kaldi/itf/online-feature-itf.h create mode 100644 torchaudio/csrc/kaldi/itf/options-itf.h create mode 100644 torchaudio/csrc/kaldi/kaldi.cc create mode 100644 torchaudio/csrc/kaldi/kaldi.h create mode 100644 torchaudio/csrc/kaldi/matrix/compressed-matrix.h create mode 100644 torchaudio/csrc/kaldi/matrix/kaldi-matrix.cc create mode 100644 torchaudio/csrc/kaldi/matrix/kaldi-matrix.h create mode 100644 torchaudio/csrc/kaldi/matrix/kaldi-vector.cc create mode 100644 torchaudio/csrc/kaldi/matrix/kaldi-vector.h create mode 100644 torchaudio/csrc/kaldi/matrix/matrix-common.h create mode 100644 torchaudio/csrc/kaldi/matrix/matrix-functions-inl.h create mode 100644 torchaudio/csrc/kaldi/matrix/matrix-functions.h create mode 100644 torchaudio/csrc/kaldi/matrix/matrix-lib.h create mode 100644 torchaudio/csrc/kaldi/matrix/optimization.h create mode 100644 torchaudio/csrc/kaldi/matrix/srfft.h create mode 100644 torchaudio/csrc/kaldi/register.cpp create mode 100644 torchaudio/csrc/kaldi/util/common-utils.h create mode 100644 torchaudio/csrc/kaldi/util/parse-options.h diff --git a/build_tools/setup_helpers/extension.py b/build_tools/setup_helpers/extension.py index 7f0ba7a14f9..44cadcd7cfd 100644 --- a/build_tools/setup_helpers/extension.py +++ b/build_tools/setup_helpers/extension.py @@ -64,6 +64,8 @@ def _get_ela(debug): def _get_srcs(): srcs = [_CSRC_DIR / 'pybind.cpp'] srcs += list(_CSRC_DIR.glob('sox/**/*.cpp')) + srcs += list(_CSRC_DIR.glob('kaldi/**/*.cpp')) + srcs += list(_CSRC_DIR.glob('kaldi/**/*.cc')) if _BUILD_TRANSDUCER: srcs += [_CSRC_DIR / 'transducer.cpp'] return [str(path) for path in srcs] @@ -72,6 +74,7 @@ def _get_srcs(): def 
_get_include_dirs(): dirs = [ str(_ROOT_DIR), + str(_CSRC_DIR / 'kaldi'), ] if _BUILD_SOX or _BUILD_TRANSDUCER: dirs.append(str(_TP_INSTALL_DIR / 'include')) diff --git a/test/torchaudio_unittest/assets/kaldi_test_pitch_args.json b/test/torchaudio_unittest/assets/kaldi_test_pitch_args.json new file mode 100644 index 00000000000..fef292e6168 --- /dev/null +++ b/test/torchaudio_unittest/assets/kaldi_test_pitch_args.json @@ -0,0 +1,5 @@ +{"sample_frequency": 8000} +{"sample_frequency": 8000, "frames_per_chunk": 200} +{"sample_frequency": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true} +{"sample_frequency": 16000} +{"sample_frequency": 44100} diff --git a/test/torchaudio_unittest/kaldi_compatibility_cpu_test.py b/test/torchaudio_unittest/kaldi_compatibility_cpu_test.py index 43be412b479..24897ab373b 100644 --- a/test/torchaudio_unittest/kaldi_compatibility_cpu_test.py +++ b/test/torchaudio_unittest/kaldi_compatibility_cpu_test.py @@ -1,7 +1,7 @@ import torch from torchaudio_unittest import common_utils -from .kaldi_compatibility_impl import Kaldi +from .kaldi_compatibility_impl import Kaldi, KaldiCPUOnly class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase): @@ -12,3 +12,8 @@ class TestKaldiFloat32(Kaldi, common_utils.PytorchTestCase): class TestKaldiFloat64(Kaldi, common_utils.PytorchTestCase): dtype = torch.float64 device = torch.device('cpu') + + +class TestKaldiCPUOnly(KaldiCPUOnly, common_utils.PytorchTestCase): + dtype = torch.float32 + device = torch.device('cpu') diff --git a/test/torchaudio_unittest/kaldi_compatibility_impl.py b/test/torchaudio_unittest/kaldi_compatibility_impl.py index 983805fff5d..a8b5fef43bd 100644 --- a/test/torchaudio_unittest/kaldi_compatibility_impl.py +++ b/test/torchaudio_unittest/kaldi_compatibility_impl.py @@ -9,10 +9,13 @@ from torchaudio_unittest.common_utils import ( TestBaseMixin, + TempDirMixin, load_params, skipIfNoExec, get_asset_path, - load_wav + load_wav, + save_wav, + get_sinusoid, ) @@ -48,11 
+51,13 @@ def _run_kaldi(command, input_type, input_value): return torch.from_numpy(result.copy()) # copy supresses some torch warning -class Kaldi(TestBaseMixin): +class KaldiTestBase(TempDirMixin, TestBaseMixin): def assert_equal(self, output, *, expected, rtol=None, atol=None): expected = expected.to(dtype=self.dtype, device=self.device) self.assertEqual(output, expected, rtol=rtol, atol=atol) + +class Kaldi(KaldiTestBase): @skipIfNoExec('apply-cmvn-sliding') def test_sliding_window_cmn(self): """sliding_window_cmn should be numerically compatible with apply-cmvn-sliding""" @@ -101,3 +106,21 @@ def test_mfcc(self, kwargs): command = ['compute-mfcc-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] kaldi_result = _run_kaldi(command, 'scp', wave_file) self.assert_equal(result, expected=kaldi_result, rtol=1e-4, atol=1e-8) + + +class KaldiCPUOnly(KaldiTestBase): + @parameterized.expand(load_params('kaldi_test_pitch_args.json')) + @skipIfNoExec('compute-kaldi-pitch-feats') + def test_pitch_feats(self, kwargs): + """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats""" + sample_rate = kwargs['sample_frequency'] + waveform = get_sinusoid(dtype='float32', sample_rate=sample_rate) + result = F.compute_kaldi_pitch(waveform[0], **kwargs) + + waveform = get_sinusoid(dtype='int16', sample_rate=sample_rate) + wave_file = self.get_temp_path('test.wav') + save_wav(wave_file, waveform, sample_rate) + + command = ['compute-kaldi-pitch-feats'] + _convert_args(**kwargs) + ['scp:-', 'ark:-'] + kaldi_result = _run_kaldi(command, 'scp', wave_file) + self.assert_equal(result, expected=kaldi_result) diff --git a/torchaudio/csrc/kaldi/base/io-funcs-inl.h b/torchaudio/csrc/kaldi/base/io-funcs-inl.h new file mode 100644 index 00000000000..b703ef5addc --- /dev/null +++ b/torchaudio/csrc/kaldi/base/io-funcs-inl.h @@ -0,0 +1,327 @@ +// base/io-funcs-inl.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin 
Qian; +// Johns Hopkins University (Author: Daniel Povey) +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_INL_H_ +#define KALDI_BASE_IO_FUNCS_INL_H_ 1 + +// Do not include this file directly. It is included by base/io-funcs.h + +#include +#include + +namespace kaldi { + +// Template that covers integers. +template void WriteBasicType(std::ostream &os, + bool binary, T t) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char len_c = (std::numeric_limits::is_signed ? 1 : -1) + * static_cast(sizeof(t)); + os.put(len_c); + os.write(reinterpret_cast(&t), sizeof(t)); + } else { + if (sizeof(t) == 1) + os << static_cast(t) << " "; + else + os << t << " "; + } + if (os.fail()) { + KALDI_ERR << "Write failure in WriteBasicType."; + } +} + +// Template that covers integers. +template inline void ReadBasicType(std::istream &is, + bool binary, T *t) { + KALDI_PARANOID_ASSERT(t != NULL); + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + int len_c_in = is.get(); + if (len_c_in == -1) + KALDI_ERR << "ReadBasicType: encountered end of stream."; + char len_c = static_cast(len_c_in), len_c_expected + = (std::numeric_limits::is_signed ? 
1 : -1) + * static_cast(sizeof(*t)); + if (len_c != len_c_expected) { + KALDI_ERR << "ReadBasicType: did not get expected integer type, " + << static_cast(len_c) + << " vs. " << static_cast(len_c_expected) + << ". You can change this code to successfully" + << " read it later, if needed."; + // insert code here to read "wrong" type. Might have a switch statement. + } + is.read(reinterpret_cast(t), sizeof(*t)); + } else { + if (sizeof(*t) == 1) { + int16 i; + is >> i; + *t = i; + } else { + is >> *t; + } + } + if (is.fail()) { + KALDI_ERR << "Read failure in ReadBasicType, file position is " + << is.tellg() << ", next char is " << is.peek(); + } +} + +// Template that covers integers. +template +inline void WriteIntegerPairVector(std::ostream &os, bool binary, + const std::vector > &v) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32 vecsz = static_cast(v.size()); + KALDI_ASSERT((size_t)vecsz == v.size()); + os.write(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast(&(v[0])), sizeof(T) * vecsz * 2); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector >::const_iterator iter = v.begin(), + end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast(iter->first) << ',' + << static_cast(iter->second) << ' '; + else + os << iter->first << ',' + << iter->second << ' '; + } + os << "]\n"; + } + if (os.fail()) { + KALDI_ERR << "Write failure in WriteIntegerPairVector."; + } +} + +// Template that covers integers. 
+template +inline void ReadIntegerPairVector(std::istream &is, bool binary, + std::vector > *v) { + KALDI_ASSERT_IS_INTEGER_TYPE(T); + KALDI_ASSERT(v != NULL); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. + KALDI_ERR << "ReadIntegerPairVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz << ", at file position " + << is.tellg(); + } + int32 vecsz; + is.read(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + v->resize(vecsz); + if (vecsz > 0) { + is.read(reinterpret_cast(&((*v)[0])), sizeof(T)*vecsz*2); + } + } else { + std::vector > tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. + is >> std::ws; + if (is.peek() != static_cast('[')) { + KALDI_ERR << "ReadIntegerPairVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16 next_t1, next_t2; + is >> next_t1; + if (is.fail()) goto bad; + if (is.peek() != static_cast(',')) + KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw " + << is.peek() << ", at file position " << is.tellg(); + is.get(); // consume the ','. + is >> next_t2 >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(std::make_pair((T)next_t1, (T)next_t2)); + } else { + T next_t1, next_t2; + is >> next_t1; + if (is.fail()) goto bad; + if (is.peek() != static_cast(',')) + KALDI_ERR << "ReadIntegerPairVector: expected to see ',', saw " + << is.peek() << ", at file position " << is.tellg(); + is.get(); // consume the ','. + is >> next_t2 >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(std::pair(next_t1, next_t2)); + } + } + is.get(); // get the final ']'. 
+ *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. + } + if (!is.fail()) return; + bad: + KALDI_ERR << "ReadIntegerPairVector: read failure at file position " + << is.tellg(); +} + +template inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector &v) { + // Compile time assertion that this is not called with a wrong type. + KALDI_ASSERT_IS_INTEGER_TYPE(T); + if (binary) { + char sz = sizeof(T); // this is currently just a check. + os.write(&sz, 1); + int32 vecsz = static_cast(v.size()); + KALDI_ASSERT((size_t)vecsz == v.size()); + os.write(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (vecsz != 0) { + os.write(reinterpret_cast(&(v[0])), sizeof(T)*vecsz); + } + } else { + // focus here is on prettiness of text form rather than + // efficiency of reading-in. + // reading-in is dominated by low-level operations anyway: + // for efficiency use binary. + os << "[ "; + typename std::vector::const_iterator iter = v.begin(), end = v.end(); + for (; iter != end; ++iter) { + if (sizeof(T) == 1) + os << static_cast(*iter) << " "; + else + os << *iter << " "; + } + os << "]\n"; + } + if (os.fail()) { + KALDI_ERR << "Write failure in WriteIntegerVector."; + } +} + + +template inline void ReadIntegerVector(std::istream &is, + bool binary, + std::vector *v) { + KALDI_ASSERT_IS_INTEGER_TYPE(T); + KALDI_ASSERT(v != NULL); + if (binary) { + int sz = is.peek(); + if (sz == sizeof(T)) { + is.get(); + } else { // this is currently just a check. + KALDI_ERR << "ReadIntegerVector: expected to see type of size " + << sizeof(T) << ", saw instead " << sz << ", at file position " + << is.tellg(); + } + int32 vecsz; + is.read(reinterpret_cast(&vecsz), sizeof(vecsz)); + if (is.fail() || vecsz < 0) goto bad; + v->resize(vecsz); + if (vecsz > 0) { + is.read(reinterpret_cast(&((*v)[0])), sizeof(T)*vecsz); + } + } else { + std::vector tmp_v; // use temporary so v doesn't use extra memory + // due to resizing. 
+ is >> std::ws; + if (is.peek() != static_cast('[')) { + KALDI_ERR << "ReadIntegerVector: expected to see [, saw " + << is.peek() << ", at file position " << is.tellg(); + } + is.get(); // consume the '['. + is >> std::ws; // consume whitespace. + while (is.peek() != static_cast(']')) { + if (sizeof(T) == 1) { // read/write chars as numbers. + int16 next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back((T)next_t); + } else { + T next_t; + is >> next_t >> std::ws; + if (is.fail()) goto bad; + else + tmp_v.push_back(next_t); + } + } + is.get(); // get the final ']'. + *v = tmp_v; // could use std::swap to use less temporary memory, but this + // uses less permanent memory. + } + if (!is.fail()) return; + bad: + KALDI_ERR << "ReadIntegerVector: read failure at file position " + << is.tellg(); +} + + +// Initialize an opened stream for writing by writing an optional binary +// header and modifying the floating-point precision. +inline void InitKaldiOutputStream(std::ostream &os, bool binary) { + // This does not throw exceptions (does not check for errors). + if (binary) { + os.put('\0'); + os.put('B'); + } + // Note, in non-binary mode we may at some point want to mess with + // the precision a bit. + // 7 is a bit more than the precision of float.. + if (os.precision() < 7) + os.precision(7); +} + +/// Initialize an opened stream for reading by detecting the binary header and +// setting the "binary" value appropriately. +inline bool InitKaldiInputStream(std::istream &is, bool *binary) { + // Sets the 'binary' variable. + // Throws exception in the very unusual situation that stream + // starts with '\0' but not then 'B'. + + if (is.peek() == '\0') { // seems to be binary + is.get(); + if (is.peek() != 'B') { + return false; + } + is.get(); + *binary = true; + return true; + } else { + *binary = false; + return true; + } +} + +} // end namespace kaldi. 
+ +#endif // KALDI_BASE_IO_FUNCS_INL_H_ diff --git a/torchaudio/csrc/kaldi/base/io-funcs.h b/torchaudio/csrc/kaldi/base/io-funcs.h new file mode 100644 index 00000000000..895f661ecee --- /dev/null +++ b/torchaudio/csrc/kaldi/base/io-funcs.h @@ -0,0 +1,245 @@ +// base/io-funcs.h + +// Copyright 2009-2011 Microsoft Corporation; Saarland University; +// Jan Silovsky; Yanmin Qian +// 2016 Xiaohui Zhang + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_IO_FUNCS_H_ +#define KALDI_BASE_IO_FUNCS_H_ + +// This header only contains some relatively low-level I/O functions. +// The full Kaldi I/O declarations are in ../util/kaldi-io.h +// and ../util/kaldi-table.h +// They were put in util/ in order to avoid making the Matrix library +// dependent on them. + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "base/io-funcs-inl.h" + +namespace kaldi { + + + +/* + This comment describes the Kaldi approach to I/O. All objects can be written + and read in two modes: binary and text. In addition we want to make the I/O + work if we redefine the typedef "BaseFloat" between floats and doubles. + We also want to have control over whitespace in text mode without affecting + the meaning of the file, for pretty-printing purposes. + + Errors are handled by throwing a KaldiFatalError exception. 
+ + For integer and floating-point types (and boolean values): + + WriteBasicType(std::ostream &, bool binary, const T&); + ReadBasicType(std::istream &, bool binary, T*); + + and we expect these functions to be defined in such a way that they work when + the type T changes between float and double, so you can read float into double + and vice versa. Note that for efficiency and space-saving reasons, the Vector + and Matrix classes do not use these functions [but they preserve the type + interchangeability in their own way]. + + For a class (or struct) C: + class C { + .. + Write(std::ostream &, bool binary, [possibly extra optional args for specific classes]) const; + Read(std::istream &, bool binary, [possibly extra optional args for specific classes]); + .. + } + NOTE: The only actual optional args we used are the "add" arguments in + Vector/Matrix classes, which specify whether we should sum the data already + in the class with the data being read. + + For types which are typedefs involving STL classes, I/O is as follows: + typedef std::vector<std::pair<A, B> > MyTypedefName; + + The user should define something like: + + WriteMyTypedefName(std::ostream &, bool binary, const MyTypedefName &t); + ReadMyTypedefName(std::istream &, bool binary, MyTypedefName *t); + + The user would have to write these functions. + + For a type std::vector<T>: + + void WriteIntegerVector(std::ostream &os, bool binary, const std::vector<T> &v); + void ReadIntegerVector(std::istream &is, bool binary, std::vector<T> *v); + + For other types, e.g. vectors of pairs, the user should create a routine of the + type WriteMyTypedefName. This is to avoid introducing confusing templated functions; + we could easily create templated functions to handle most of these cases but they + would have to share the same name. + + It also often happens that the user needs to write/read special tokens as part + of a file. These might be class headers, or separators/identifiers in the class. 
+ We provide special functions for manipulating these. These special tokens must + be nonempty and must not contain any whitespace. + + void WriteToken(std::ostream &os, bool binary, const char*); + void WriteToken(std::ostream &os, bool binary, const std::string & token); + int Peek(std::istream &is, bool binary); + void ReadToken(std::istream &is, bool binary, std::string *str); + int PeekToken(std::istream &is, bool binary); + + WriteToken writes the token and one space (whether in binary or text mode). + + Peek returns the first character of the next token, by consuming whitespace + (in text mode) and then returning the peek() character. It returns -1 at EOF; + it doesn't throw. It's useful if a class can have various forms based on + typedefs and virtual classes, and wants to know which version to read. + + ReadToken allows the caller to obtain the next token. PeekToken works just + like Peek, except that it skips over a leading '<' and returns the character + after it; it then tries to put the '<' back, so a subsequent call to + ReadToken will normally read the same token. This is useful when + different object types are written to the same file; using PeekToken one can + decide which of the objects to read. + + There is currently no special functionality for writing/reading strings (where the strings + contain data rather than "special tokens" that are whitespace-free and nonempty). This is + because Kaldi is structured in such a way that strings don't appear, except as OpenFst symbol + table entries (and these have their own format). + + + NOTE: you should not call ReadBasicType and WriteBasicType with types, + such as int and size_t, that are machine-dependent -- at least not + if you want your file formats to port between machines. Use int32 and + int64 where necessary. There is no way to detect this using compile-time + assertions because C++ only keeps track of the internal representation of + the type. 
+*/ + +/// \addtogroup io_funcs_basic +/// @{ + + +/// WriteBasicType is the name of the write function for bool, integer types, +/// and floating-point types. They all throw on error. +template void WriteBasicType(std::ostream &os, bool binary, T t); + +/// ReadBasicType is the name of the read function for bool, integer types, +/// and floating-point types. They all throw on error. +template void ReadBasicType(std::istream &is, bool binary, T *t); + + +// Declare specialization for bool. +template<> +void WriteBasicType(std::ostream &os, bool binary, bool b); + +template <> +void ReadBasicType(std::istream &is, bool binary, bool *b); + +// Declare specializations for float and double. +template<> +void WriteBasicType(std::ostream &os, bool binary, float f); + +template<> +void WriteBasicType(std::ostream &os, bool binary, double f); + +template<> +void ReadBasicType(std::istream &is, bool binary, float *f); + +template<> +void ReadBasicType(std::istream &is, bool binary, double *f); + +// Define ReadBasicType that accepts an "add" parameter to add to +// the destination. Caution: if used in Read functions, be careful +// to initialize the parameters concerned to zero in the default +// constructor. +template +inline void ReadBasicType(std::istream &is, bool binary, T *t, bool add) { + if (!add) { + ReadBasicType(is, binary, t); + } else { + T tmp = T(0); + ReadBasicType(is, binary, &tmp); + *t += tmp; + } +} + +/// Function for writing STL vectors of integer types. +template inline void WriteIntegerVector(std::ostream &os, bool binary, + const std::vector &v); + +/// Function for reading STL vector of integer types. +template inline void ReadIntegerVector(std::istream &is, bool binary, + std::vector *v); + +/// Function for writing STL vectors of pairs of integer types. +template +inline void WriteIntegerPairVector(std::ostream &os, bool binary, + const std::vector > &v); + +/// Function for reading STL vector of pairs of integer types. 
+template +inline void ReadIntegerPairVector(std::istream &is, bool binary, + std::vector > *v); + +/// The WriteToken functions are for writing nonempty sequences of non-space +/// characters. They are not for general strings. +void WriteToken(std::ostream &os, bool binary, const char *token); +void WriteToken(std::ostream &os, bool binary, const std::string & token); + +/// Peek consumes whitespace (if binary == false) and then returns the peek() +/// value of the stream. +int Peek(std::istream &is, bool binary); + +/// ReadToken gets the next token and puts it in str (exception on failure). If +/// PeekToken() had been previously called, it is possible that the stream had +/// failed to unget the starting '<' character. In this case ReadToken() returns +/// the token string without the leading '<'. You must be prepared to handle +/// this case. ExpectToken() handles this internally, and is not affected. +void ReadToken(std::istream &is, bool binary, std::string *token); + +/// PeekToken will return the first character of the next token, or -1 if end of +/// file. It's the same as Peek(), except if the first character is '<' it will +/// skip over it and will return the next character. It will attempt to unget +/// the '<' so the stream is where it was before you did PeekToken(), however, +/// this is not guaranteed (see ReadToken()). +int PeekToken(std::istream &is, bool binary); + +/// ExpectToken tries to read in the given token, and throws an exception +/// on failure. +void ExpectToken(std::istream &is, bool binary, const char *token); +void ExpectToken(std::istream &is, bool binary, const std::string & token); + +/// ExpectPretty attempts to read the text in "token", but only in non-binary +/// mode. Throws exception on failure. It expects an exact match except that +/// arbitrary whitespace matches arbitrary whitespace. 
+void ExpectPretty(std::istream &is, bool binary, const char *token); +void ExpectPretty(std::istream &is, bool binary, const std::string & token); + +/// @} end "addtogroup io_funcs_basic" + + +/// InitKaldiOutputStream initializes an opened stream for writing by writing an +/// optional binary header and modifying the floating-point precision; it will +/// typically not be called by users directly. +inline void InitKaldiOutputStream(std::ostream &os, bool binary); + +/// InitKaldiInputStream initializes an opened stream for reading by detecting +/// the binary header and setting the "binary" value appropriately; +/// It will typically not be called by users directly. +inline bool InitKaldiInputStream(std::istream &is, bool *binary); + +} // end namespace kaldi. +#endif // KALDI_BASE_IO_FUNCS_H_ diff --git a/torchaudio/csrc/kaldi/base/kaldi-common.h b/torchaudio/csrc/kaldi/base/kaldi-common.h new file mode 100644 index 00000000000..264565d1812 --- /dev/null +++ b/torchaudio/csrc/kaldi/base/kaldi-common.h @@ -0,0 +1,41 @@ +// base/kaldi-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_BASE_KALDI_COMMON_H_ +#define KALDI_BASE_KALDI_COMMON_H_ 1 + +#include +#include +#include // C string stuff like strcpy +#include +#include +#include +#include +#include +#include +#include + +#include "base/kaldi-utils.h" +#include "base/kaldi-error.h" +#include "base/kaldi-types.h" +#include "base/io-funcs.h" +#include "base/kaldi-math.h" +#include "base/timer.h" + +#endif // KALDI_BASE_KALDI_COMMON_H_ diff --git a/torchaudio/csrc/kaldi/base/kaldi-error.cc b/torchaudio/csrc/kaldi/base/kaldi-error.cc new file mode 100644 index 00000000000..12f972ee856 --- /dev/null +++ b/torchaudio/csrc/kaldi/base/kaldi-error.cc @@ -0,0 +1,249 @@ +// base/kaldi-error.cc + +// Copyright 2019 LAIX (Yi Sun) +// Copyright 2019 SmartAction LLC (kkm) +// Copyright 2016 Brno University of Technology (author: Karel Vesely) +// Copyright 2009-2011 Microsoft Corporation; Lukas Burget; Ondrej Glembek + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifdef HAVE_EXECINFO_H +#include // To get stack trace in error messages. +// If this #include fails there is an error in the Makefile, it does not +// support your platform well. Make sure HAVE_EXECINFO_H is undefined, +// and the code will compile. +#ifdef HAVE_CXXABI_H +#include // For name demangling. 
// Shorten a source path so that at most one directory component is kept in
// front of the file name: "/a/b/c/d/e/f.cc" becomes "e/f.cc". Both '/' and
// '\' are accepted as separators. The result points into the argument (nothing
// is copied); a null argument yields the empty string literal.
static const char *GetShortFileName(const char *path) {
  if (!path) return "";

  // 'tail' is the text after the most recent separator; 'before_tail' lags it
  // by exactly one separator, which is what leaves one directory in the result.
  const char *before_tail = path;
  const char *tail = path;
  for (const char *sep = std::strpbrk(path, "\\/"); sep != nullptr;
       sep = std::strpbrk(sep + 1, "\\/")) {
    before_tail = tail;
    tail = sep + 1;
  }
  return before_tail;
}
#ifdef HAVE_EXECINFO_H
// Replace the mangled C++ symbol inside one backtrace line with its demangled
// form. The line is returned unchanged when cxxabi.h is not available, when no
// symbol can be located in it, or when the symbol fails to demangle.
static std::string Demangle(std::string trace_name) {
#ifndef HAVE_CXXABI_H
  return trace_name;
#else  // HAVE_CXXABI_H
  // Backtrace line formats we try to handle:
  //
  // Linux:
  // ./kaldi-error-test(_ZN5kaldi13UnitTestErrorEv+0xb) [0x804965d]
  //
  // Mac:
  // 0 server 0x000000010f67614d _ZNK5kaldi13MessageLogger10LogMessageEv + 813
  //
  // The goal is to cut out the mangled name, e.g.
  // '_ZN5kaldi13UnitTestErrorEv', and turn it into a readable name such as
  // kaldi::UnitTestError.
  size_t first = 0, last = 0;
  const bool found = internal::LocateSymbolRange(trace_name, &first, &last);
  if (!found) return trace_name;

  const std::string mangled(trace_name, first, last - first);
  int rc = 0;
  char *readable = abi::__cxa_demangle(mangled.c_str(), 0, 0, &rc);

  // Rebuild the line as prefix + (demangled or original symbol) + suffix.
  std::string result = trace_name.substr(0, first);
  if (rc == 0 && readable != nullptr) {
    result += readable;
    free(readable);  // The buffer returned by __cxa_demangle is ours to free.
  } else {
    result += mangled;
  }
  result += trace_name.substr(last, std::string::npos);
  return result;
#endif  // HAVE_CXXABI_H
}
#endif  // HAVE_EXECINFO_H
+ ans += "[ Stack-Trace: ]\n"; + if (size <= KALDI_MAX_TRACE_PRINT) { + for (size_t i = 0; i < size; i++) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + } else { // Print out first+last (e.g.) 5. + for (size_t i = 0; i < KALDI_MAX_TRACE_PRINT / 2; i++) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + ans += ".\n.\n.\n"; + for (size_t i = size - KALDI_MAX_TRACE_PRINT / 2; i < size; i++) { + ans += Demangle(trace_symbol[i]) + "\n"; + } + if (size == KALDI_MAX_TRACE_SIZE) + ans += ".\n.\n.\n"; // Stack was too long, probably a bug. + } + + // We must free the array of pointers allocated by backtrace_symbols(), + // but not the strings themselves. + free(trace_symbol); +#endif // HAVE_EXECINFO_H + return ans; +} + +/***** KALDI LOGGING *****/ + +MessageLogger::MessageLogger(LogMessageEnvelope::Severity severity, + const char *func, const char *file, int32 line) { + // Obviously, we assume the strings survive the destruction of this object. + envelope_.severity = severity; + envelope_.func = func; + envelope_.file = GetShortFileName(file); // Points inside 'file'. + envelope_.line = line; +} + +void MessageLogger::LogMessage() const { + // Send to the logging handler if provided. + if (log_handler != NULL) { + log_handler(envelope_, GetMessage().c_str()); + return; + } + + // Otherwise, use the default Kaldi logging. + // Build the log-message header. + std::stringstream full_message; + if (envelope_.severity > LogMessageEnvelope::kInfo) { + full_message << "VLOG[" << envelope_.severity << "] ("; + } else { + switch (envelope_.severity) { + case LogMessageEnvelope::kInfo: + full_message << "LOG ("; + break; + case LogMessageEnvelope::kWarning: + full_message << "WARNING ("; + break; + case LogMessageEnvelope::kAssertFailed: + full_message << "ASSERTION_FAILED ("; + break; + case LogMessageEnvelope::kError: + default: // If not the ERROR, it still an error! + full_message << "ERROR ("; + break; + } + } + // Add other info from the envelope and the message text. 
+ full_message << program_name.c_str() << "[" KALDI_VERSION "]" << ':' + << envelope_.func << "():" << envelope_.file << ':' + << envelope_.line << ") " << GetMessage().c_str(); + + // Add stack trace for errors and assertion failures, if available. + if (envelope_.severity < LogMessageEnvelope::kWarning) { + const std::string &stack_trace = KaldiGetStackTrace(); + if (!stack_trace.empty()) { + full_message << "\n\n" << stack_trace; + } + } + + // Print the complete message to stderr. + full_message << "\n"; + std::cerr << full_message.str(); +} + +/***** KALDI ASSERTS *****/ + +void KaldiAssertFailure_(const char *func, const char *file, int32 line, + const char *cond_str) { + MessageLogger::Log() = + MessageLogger(LogMessageEnvelope::kAssertFailed, func, file, line) + << "Assertion failed: (" << cond_str << ")"; + fflush(NULL); // Flush all pending buffers, abort() may not flush stderr. + std::abort(); +} + +/***** THIRD-PARTY LOG-HANDLER *****/ + +LogHandler SetLogHandler(LogHandler handler) { + LogHandler old_handler = log_handler; + log_handler = handler; + return old_handler; +} + +} // namespace kaldi diff --git a/torchaudio/csrc/kaldi/base/kaldi-error.h b/torchaudio/csrc/kaldi/base/kaldi-error.h new file mode 100644 index 00000000000..a9904a752cd --- /dev/null +++ b/torchaudio/csrc/kaldi/base/kaldi-error.h @@ -0,0 +1,231 @@ +// base/kaldi-error.h + +// Copyright 2019 LAIX (Yi Sun) +// Copyright 2019 SmartAction LLC (kkm) +// Copyright 2016 Brno University of Technology (author: Karel Vesely) +// Copyright 2009-2011 Microsoft Corporation; Ondrej Glembek; Lukas Burget; +// Saarland University + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_ERROR_H_ +#define KALDI_BASE_KALDI_ERROR_H_ 1 + +#include +#include +#include +#include +#include +#include + +#include "base/kaldi-types.h" +#include "base/kaldi-utils.h" +/* Important that this file does not depend on any other kaldi headers. */ + +#ifdef _MSC_VER +#define __func__ __FUNCTION__ +#endif + +namespace kaldi { + +/// \addtogroup error_group +/// @{ + +/***** PROGRAM NAME AND VERBOSITY LEVEL *****/ + +/// Called by ParseOptions to set base name (no directory) of the executing +/// program. The name is printed in logging code along with every message, +/// because in our scripts, we often mix together the stderr of many programs. +/// This function is very thread-unsafe. +void SetProgramName(const char *basename); + +/// This is set by util/parse-options.{h,cc} if you set --verbose=? option. +/// Do not use directly, prefer {Get,Set}VerboseLevel(). +extern int32 g_kaldi_verbose_level; + +/// Get verbosity level, usually set via command line '--verbose=' switch. +inline int32 GetVerboseLevel() { return g_kaldi_verbose_level; } + +/// This should be rarely used, except by programs using Kaldi as library; +/// command-line programs set the verbose level automatically from ParseOptions. +inline void SetVerboseLevel(int32 i) { g_kaldi_verbose_level = i; } + +/***** KALDI LOGGING *****/ + +/// Log message severity and source location info. +struct LogMessageEnvelope { + /// Message severity. 
In addition to these levels, positive values (1 to 6) + /// specify verbose logging level. Verbose messages are produced only when + /// SetVerboseLevel() has been called to set logging level to at least the + /// corresponding value. + enum Severity { + kAssertFailed = -3, //!< Assertion failure. abort() will be called. + kError = -2, //!< Fatal error. KaldiFatalError will be thrown. + kWarning = -1, //!< Indicates a recoverable but abnormal condition. + kInfo = 0, //!< Informational message. + }; + int severity; //!< A Severity value, or positive verbosity level. + const char *func; //!< Name of the function invoking the logging. + const char *file; //!< Source file name with up to 1 leading directory. + int32 line; // MessageLogger &operator<<(const T &val) { + ss_ << val; + return *this; + } + + // When assigned a MessageLogger, log its contents. + struct Log final { + void operator=(const MessageLogger &logger) { logger.LogMessage(); } + }; + + // When assigned a MessageLogger, log its contents and then throw + // a KaldiFatalError. + struct LogAndThrow final { + [[noreturn]] void operator=(const MessageLogger &logger) { + logger.LogMessage(); + throw KaldiFatalError(logger.GetMessage()); + } + }; + +private: + std::string GetMessage() const { return ss_.str(); } + void LogMessage() const; + + LogMessageEnvelope envelope_; + std::ostringstream ss_; +}; + +// Logging macros. 
// Verbose logging: the message is built and logged only when the requested
// level (v) does not exceed the current verbose level.
//
// The macro is a conditional *expression*, not a bare `if` statement. With a
// bare `if`, code such as
//
//   if (condition)
//     KALDI_VLOG(1) << "message";
//   else
//     SomethingElse();
//
// silently binds the `else` to the `if` hidden inside the macro. As an
// expression the macro cannot capture a following `else`. Both arms have type
// void (MessageLogger::Log::operator= returns void), and because `<<` binds
// tighter than `=`, everything streamed after the macro still ends up in the
// MessageLogger temporary before it is assigned to Log().
#define KALDI_VLOG(v)                                                          \
  ((v) > ::kaldi::GetVerboseLevel())                                           \
      ? (void)0                                                                \
      : ::kaldi::MessageLogger::Log() =                                        \
            ::kaldi::MessageLogger((::kaldi::LogMessageEnvelope::Severity)(v), \
                                   __func__, __FILE__, __LINE__)
+#ifdef KALDI_PARANOID +#define KALDI_PARANOID_ASSERT(cond) \ + do { \ + if (cond) \ + (void)0; \ + else \ + ::kaldi::KaldiAssertFailure_(__func__, __FILE__, __LINE__, #cond); \ + } while (0) +#else +#define KALDI_PARANOID_ASSERT(cond) (void)0 +#endif + +/***** THIRD-PARTY LOG-HANDLER *****/ + +/// Type of third-party logging function. +typedef void (*LogHandler)(const LogMessageEnvelope &envelope, + const char *message); + +/// Set logging handler. If called with a non-NULL function pointer, the +/// function pointed by it is called to send messages to a caller-provided log. +/// If called with a NULL pointer, restores default Kaldi error logging to +/// stderr. This function is obviously not thread safe; the log handler must be. +/// Returns a previously set logging handler pointer, or NULL. +LogHandler SetLogHandler(LogHandler); + +/// @} end "addtogroup error_group" + +// Functions within internal is exported for testing only, do not use. +namespace internal { +bool LocateSymbolRange(const std::string &trace_name, size_t *begin, + size_t *end); +} // namespace internal +} // namespace kaldi + +#endif // KALDI_BASE_KALDI_ERROR_H_ diff --git a/torchaudio/csrc/kaldi/base/kaldi-math.cc b/torchaudio/csrc/kaldi/base/kaldi-math.cc new file mode 100644 index 00000000000..484c80d44ee --- /dev/null +++ b/torchaudio/csrc/kaldi/base/kaldi-math.cc @@ -0,0 +1,162 @@ +// base/kaldi-math.cc + +// Copyright 2009-2011 Microsoft Corporation; Yanmin Qian; +// Saarland University; Jan Silovsky + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include "base/kaldi-math.h" +#ifndef _MSC_VER +#include +#include +#endif +#include +#include + +namespace kaldi { +// These routines are tested in matrix/matrix-test.cc + +int32 RoundUpToNearestPowerOfTwo(int32 n) { + KALDI_ASSERT(n > 0); + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n+1; +} + +static std::mutex _RandMutex; + +int Rand(struct RandomState* state) { +#if !defined(_POSIX_THREAD_SAFE_FUNCTIONS) + // On Windows and Cygwin, just call Rand() + return rand(); +#else + if (state) { + return rand_r(&(state->seed)); + } else { + std::lock_guard lock(_RandMutex); + return rand(); + } +#endif +} + +RandomState::RandomState() { + // we initialize it as Rand() + 27437 instead of just Rand(), because on some + // systems, e.g. at the very least Mac OSX Yosemite and later, it seems to be + // the case that rand_r when initialized with rand() will give you the exact + // same sequence of numbers that rand() will give if you keep calling rand() + // after that initial call. This can cause problems with repeated sequences. + // For example if you initialize two RandomState structs one after the other + // without calling rand() in between, they would give you the same sequence + // offset by one (if we didn't have the "+ 27437" in the code). 27437 is just + // a randomly chosen prime number. 
+ seed = Rand() + 27437; +} + +bool WithProb(BaseFloat prob, struct RandomState* state) { + KALDI_ASSERT(prob >= 0 && prob <= 1.1); // prob should be <= 1.0, + // but we allow slightly larger values that could arise from roundoff in + // previous calculations. + KALDI_COMPILE_TIME_ASSERT(RAND_MAX > 128 * 128); + if (prob == 0) return false; + else if (prob == 1.0) return true; + else if (prob * RAND_MAX < 128.0) { + // prob is very small but nonzero, and the "main algorithm" + // wouldn't work that well. So: with probability 1/128, we + // return WithProb (prob * 128), else return false. + if (Rand(state) < RAND_MAX / 128) { // with probability 128... + // Note: we know that prob * 128.0 < 1.0, because + // we asserted RAND_MAX > 128 * 128. + return WithProb(prob * 128.0); + } else { + return false; + } + } else { + return (Rand(state) < ((RAND_MAX + static_cast(1.0)) * prob)); + } +} + +int32 RandInt(int32 min_val, int32 max_val, struct RandomState* state) { + // This is not exact. + KALDI_ASSERT(max_val >= min_val); + if (max_val == min_val) return min_val; + +#ifdef _MSC_VER + // RAND_MAX is quite small on Windows -> may need to handle larger numbers. + if (RAND_MAX > (max_val-min_val)*8) { + // *8 to avoid large inaccuracies in probability, from the modulus... + return min_val + + ((unsigned int)Rand(state) % (unsigned int)(max_val+1-min_val)); + } else { + if ((unsigned int)(RAND_MAX*RAND_MAX) > + (unsigned int)((max_val+1-min_val)*8)) { + // *8 to avoid inaccuracies in probability, from the modulus... + return min_val + ( (unsigned int)( (Rand(state)+RAND_MAX*Rand(state))) + % (unsigned int)(max_val+1-min_val)); + } else { + KALDI_ERR << "rand_int failed because we do not support such large " + "random numbers. (Extend this function)."; + } + } +#else + return min_val + + (static_cast(Rand(state)) % static_cast(max_val+1-min_val)); +#endif +} + +// Returns poisson-distributed random number. +// Take care: this takes time proportional +// to lambda. 
Faster algorithms exist but are more complex. +int32 RandPoisson(float lambda, struct RandomState* state) { + // Knuth's algorithm. + KALDI_ASSERT(lambda >= 0); + float L = expf(-lambda), p = 1.0; + int32 k = 0; + do { + k++; + float u = RandUniform(state); + p *= u; + } while (p > L); + return k-1; +} + +void RandGauss2(float *a, float *b, RandomState *state) { + KALDI_ASSERT(a); + KALDI_ASSERT(b); + float u1 = RandUniform(state); + float u2 = RandUniform(state); + u1 = sqrtf(-2.0f * logf(u1)); + u2 = 2.0f * M_PI * u2; + *a = u1 * cosf(u2); + *b = u1 * sinf(u2); +} + +void RandGauss2(double *a, double *b, RandomState *state) { + KALDI_ASSERT(a); + KALDI_ASSERT(b); + float a_float, b_float; + // Just because we're using doubles doesn't mean we need super-high-quality + // random numbers, so we just use the floating-point version internally. + RandGauss2(&a_float, &b_float, state); + *a = a_float; + *b = b_float; +} + + +} // end namespace kaldi diff --git a/torchaudio/csrc/kaldi/base/kaldi-math.h b/torchaudio/csrc/kaldi/base/kaldi-math.h new file mode 100644 index 00000000000..93c265ee96e --- /dev/null +++ b/torchaudio/csrc/kaldi/base/kaldi-math.h @@ -0,0 +1,363 @@ +// base/kaldi-math.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Yanmin Qian; +// Jan Silovsky; Saarland University +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_BASE_KALDI_MATH_H_ +#define KALDI_BASE_KALDI_MATH_H_ 1 + +#ifdef _MSC_VER +#include +#endif + +#include +#include +#include + +#include "base/kaldi-types.h" +#include "base/kaldi-common.h" + + +#ifndef DBL_EPSILON +#define DBL_EPSILON 2.2204460492503131e-16 +#endif +#ifndef FLT_EPSILON +#define FLT_EPSILON 1.19209290e-7f +#endif + +#ifndef M_PI +#define M_PI 3.1415926535897932384626433832795 +#endif + +#ifndef M_SQRT2 +#define M_SQRT2 1.4142135623730950488016887 +#endif + +#ifndef M_2PI +#define M_2PI 6.283185307179586476925286766559005 +#endif + +#ifndef M_SQRT1_2 +#define M_SQRT1_2 0.7071067811865475244008443621048490 +#endif + +#ifndef M_LOG_2PI +#define M_LOG_2PI 1.8378770664093454835606594728112 +#endif + +#ifndef M_LN2 +#define M_LN2 0.693147180559945309417232121458 +#endif + +#ifndef M_LN10 +#define M_LN10 2.302585092994045684017991454684 +#endif + + +#define KALDI_ISNAN std::isnan +#define KALDI_ISINF std::isinf +#define KALDI_ISFINITE(x) std::isfinite(x) + +#if !defined(KALDI_SQR) +# define KALDI_SQR(x) ((x) * (x)) +#endif + +namespace kaldi { + +#if !defined(_MSC_VER) || (_MSC_VER >= 1900) +inline double Exp(double x) { return exp(x); } +#ifndef KALDI_NO_EXPF +inline float Exp(float x) { return expf(x); } +#else +inline float Exp(float x) { return exp(static_cast(x)); } +#endif // KALDI_NO_EXPF +#else +inline double Exp(double x) { return exp(x); } +#if !defined(__INTEL_COMPILER) && _MSC_VER == 1800 && defined(_M_X64) +// Microsoft CL v18.0 buggy 64-bit implementation of +// expf() incorrectly returns -inf for exp(-inf). 
// Log1p(x) computes log(1 + x).
#if !defined(_MSC_VER) || (_MSC_VER >= 1700)
// The C library provides log1p / log1pf; just forward to them.
inline double Log1p(double x) { return log1p(x); }
inline float Log1p(float x) { return log1pf(x); }
#else
// Fallback for old MSVC versions: for |x| below the cutoff use the
// second-order Taylor expansion x - x^2/2, otherwise evaluate Log(1 + x)
// directly. The cutoff has to be applied to the magnitude of x; testing only
// `x < cutoff` would send every negative argument (e.g. -0.5) down the Taylor
// branch and return a badly wrong value.
inline double Log1p(double x) {
  const double cutoff = 1.0e-08;
  if (x < cutoff && x > -cutoff)
    return x - 0.5 * x * x;
  else
    return Log(1.0 + x);
}

inline float Log1p(float x) {
  const float cutoff = 1.0e-07;
  if (x < cutoff && x > -cutoff)
    return x - 0.5 * x * x;
  else
    return Log(1.0 + x);
}
#endif
+inline float RandUniform(struct RandomState* state = NULL) { + return static_cast((Rand(state) + 1.0) / (RAND_MAX+2.0)); +} + +inline float RandGauss(struct RandomState* state = NULL) { + return static_cast(sqrtf (-2 * Log(RandUniform(state))) + * cosf(2*M_PI*RandUniform(state))); +} + +// Returns poisson-distributed random number. Uses Knuth's algorithm. +// Take care: this takes time proportional +// to lambda. Faster algorithms exist but are more complex. +int32 RandPoisson(float lambda, struct RandomState* state = NULL); + +// Returns a pair of gaussian random numbers. Uses Box-Muller transform +void RandGauss2(float *a, float *b, RandomState *state = NULL); +void RandGauss2(double *a, double *b, RandomState *state = NULL); + +// Also see Vector::RandCategorical(). + +// This is a randomized pruning mechanism that preserves expectations, +// that we typically use to prune posteriors. +template +inline Float RandPrune(Float post, BaseFloat prune_thresh, + struct RandomState* state = NULL) { + KALDI_ASSERT(prune_thresh >= 0.0); + if (post == 0.0 || std::abs(post) >= prune_thresh) + return post; + return (post >= 0 ? 1.0 : -1.0) * + (RandUniform(state) <= fabs(post)/prune_thresh ? prune_thresh : 0.0); +} + +// returns log(exp(x) + exp(y)). +inline double LogAdd(double x, double y) { + double diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffDouble) { + double res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + + +// returns log(exp(x) + exp(y)). +inline float LogAdd(float x, float y) { + float diff; + + if (x < y) { + diff = x - y; + x = y; + } else { + diff = y - x; + } + // diff is negative. x is now the larger one. + + if (diff >= kMinLogDiffFloat) { + float res; + res = x + Log1p(Exp(diff)); + return res; + } else { + return x; // return the larger one. + } +} + + +// returns log(exp(x) - exp(y)). 
+inline double LogSub(double x, double y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + double diff = y - x; // Will be negative. + double res = x + Log(1.0 - Exp(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroDouble; + return res; +} + + +// returns log(exp(x) - exp(y)). +inline float LogSub(float x, float y) { + if (y >= x) { // Throws exception if y>=x. + if (y == x) + return kLogZeroDouble; + else + KALDI_ERR << "Cannot subtract a larger from a smaller number."; + } + + float diff = y - x; // Will be negative. + float res = x + Log(1.0f - Exp(diff)); + + // res might be NAN if diff ~0.0, and 1.0-exp(diff) == 0 to machine precision + if (KALDI_ISNAN(res)) + return kLogZeroFloat; + return res; +} + +/// return abs(a - b) <= relative_tolerance * (abs(a)+abs(b)). +static inline bool ApproxEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + if (a == b) return true; + float diff = std::abs(a-b); + if (diff == std::numeric_limits::infinity() + || diff != diff) return false; // diff is +inf or nan. + return (diff <= relative_tolerance*(std::abs(a)+std::abs(b))); +} + +/// assert abs(a - b) <= relative_tolerance * (abs(a)+abs(b)) +static inline void AssertEqual(float a, float b, + float relative_tolerance = 0.001) { + // a==b handles infinities. + KALDI_ASSERT(ApproxEqual(a, b, relative_tolerance)); +} + + +// RoundUpToNearestPowerOfTwo does the obvious thing. It crashes if n <= 0. +int32 RoundUpToNearestPowerOfTwo(int32 n); + +/// Returns a / b, rounding towards negative infinity in all cases. 
+static inline int32 DivideRoundingDown(int32 a, int32 b) { + KALDI_ASSERT(b != 0); + if (a * b >= 0) + return a / b; + else if (a < 0) + return (a - b + 1) / b; + else + return (a - b - 1) / b; +} + +template I Gcd(I m, I n) { + if (m == 0 || n == 0) { + if (m == 0 && n == 0) { // gcd not defined, as all integers are divisors. + KALDI_ERR << "Undefined GCD since m = 0, n = 0."; + } + return (m == 0 ? (n > 0 ? n : -n) : ( m > 0 ? m : -m)); + // return absolute value of whichever is nonzero + } + // could use compile-time assertion + // but involves messing with complex template stuff. + KALDI_ASSERT(std::numeric_limits::is_integer); + while (1) { + m %= n; + if (m == 0) return (n > 0 ? n : -n); + n %= m; + if (n == 0) return (m > 0 ? m : -m); + } +} + +/// Returns the least common multiple of two integers. Will +/// crash unless the inputs are positive. +template I Lcm(I m, I n) { + KALDI_ASSERT(m > 0 && n > 0); + I gcd = Gcd(m, n); + return gcd * (m/gcd) * (n/gcd); +} + + +template void Factorize(I m, std::vector *factors) { + // Splits a number into its prime factors, in sorted order from + // least to greatest, with duplication. A very inefficient + // algorithm, which is mainly intended for use in the + // mixed-radix FFT computation (where we assume most factors + // are small). + KALDI_ASSERT(factors != NULL); + KALDI_ASSERT(m >= 1); // Doesn't work for zero or negative numbers. + factors->clear(); + I small_factors[10] = { 2, 3, 5, 7, 11, 13, 17, 19, 23, 29 }; + + // First try small factors. + for (I i = 0; i < 10; i++) { + if (m == 1) return; // We're done. + while (m % small_factors[i] == 0) { + m /= small_factors[i]; + factors->push_back(small_factors[i]); + } + } + // Next try all odd numbers starting from 31. 
// Hypot(x, y) returns sqrt(x*x + y*y) by delegating to the C++ math library.
// The double overload works and returns in double precision, the float
// overload in single precision.
inline double Hypot(double x, double y) {
  return std::hypot(x, y);
}

inline float Hypot(float x, float y) {
  return std::hypot(x, y);
}
+#if (KALDI_DOUBLEPRECISION != 0) +typedef double BaseFloat; +#else +typedef float BaseFloat; +#endif +} + +#ifdef _MSC_VER +#include +#define ssize_t SSIZE_T +#endif + +// we can do this a different way if some platform +// we find in the future lacks stdint.h +#include + +// for discussion on what to do if you need compile kaldi +// without OpenFST, see the bottom of this this file +/* +#include + +namespace kaldi { + using ::int16; + using ::int32; + using ::int64; + using ::uint16; + using ::uint32; + using ::uint64; + typedef float float32; + typedef double double64; +} // end namespace kaldi +*/ + +// In a theoretical case you decide compile Kaldi without the OpenFST +// comment the previous namespace statement and uncomment the following +namespace kaldi { + typedef int8_t int8; + typedef int16_t int16; + typedef int32_t int32; + typedef int64_t int64; + + typedef uint8_t uint8; + typedef uint16_t uint16; + typedef uint32_t uint32; + typedef uint64_t uint64; + typedef float float32; + typedef double double64; +} // end namespace kaldi + +#endif // KALDI_BASE_KALDI_TYPES_H_ diff --git a/torchaudio/csrc/kaldi/base/kaldi-utils.h b/torchaudio/csrc/kaldi/base/kaldi-utils.h new file mode 100644 index 00000000000..1c96882510a --- /dev/null +++ b/torchaudio/csrc/kaldi/base/kaldi-utils.h @@ -0,0 +1,155 @@ +// base/kaldi-utils.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; +// Saarland University; Karel Vesely; Yanmin Qian + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_BASE_KALDI_UTILS_H_
+#define KALDI_BASE_KALDI_UTILS_H_ 1
+
+#if defined(_MSC_VER)
+# define WIN32_LEAN_AND_MEAN
+# define NOMINMAX
+# include <windows.h>
+#endif
+
+#ifdef _MSC_VER
+#include <stdio.h>
+#define unlink _unlink
+#else
+#include <unistd.h>
+#endif
+
+#include <limits>
+#include <string>
+
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4056 4305 4800 4267 4996 4756 4661)
+#if _MSC_VER < 1400
+#define __restrict__
+#else
+#define __restrict__ __restrict
+#endif
+#endif
+
+#if defined(_MSC_VER)
+# define KALDI_MEMALIGN(align, size, pp_orig) \
+  (*(pp_orig) = _aligned_malloc(size, align))
+# define KALDI_MEMALIGN_FREE(x) _aligned_free(x)
+#elif defined(__CYGWIN__)
+# define KALDI_MEMALIGN(align, size, pp_orig) \
+  (*(pp_orig) = aligned_alloc(align, size))
+# define KALDI_MEMALIGN_FREE(x) free(x)
+#else
+# define KALDI_MEMALIGN(align, size, pp_orig) \
+  (!posix_memalign(pp_orig, align, size) ? *(pp_orig) : NULL)
+# define KALDI_MEMALIGN_FREE(x) free(x)
+#endif
+
+#ifdef __ICC
+#pragma warning(disable: 383) // ICPC remark we don't want.
+#pragma warning(disable: 810) // ICPC remark we don't want.
+#pragma warning(disable: 981) // ICPC remark we don't want.
+#pragma warning(disable: 1418) // ICPC remark we don't want.
+#pragma warning(disable: 444) // ICPC remark we don't want.
+#pragma warning(disable: 869) // ICPC remark we don't want.
+#pragma warning(disable: 1287) // ICPC remark we don't want.
+#pragma warning(disable: 279) // ICPC remark we don't want.
+#pragma warning(disable: 981) // ICPC remark we don't want (repeated; harmless).
+#endif
+
+
+namespace kaldi {
+
+
+// CharToString prints the character in a human-readable form, for debugging.
+std::string CharToString(const char &c);
+
+
+inline int MachineIsLittleEndian() {  // nonzero iff the first byte of an int holds its low-order bits
+  int check = 1;
+  return (*reinterpret_cast<char*>(&check) != 0);
+}
+
+// This function kaldi::Sleep() provides a portable way
+// to sleep for a possibly fractional
+// number of seconds. On Windows it's only accurate to microseconds.
+void Sleep(float seconds);
+}
+
+#define KALDI_SWAP8(a) do { \
+  int t = (reinterpret_cast<char*>(&(a)))[0];\
+  (reinterpret_cast<char*>(&(a)))[0]=(reinterpret_cast<char*>(&(a)))[7];\
+  (reinterpret_cast<char*>(&(a)))[7]=t;\
+  t = (reinterpret_cast<char*>(&(a)))[1];\
+  (reinterpret_cast<char*>(&(a)))[1]=(reinterpret_cast<char*>(&(a)))[6];\
+  (reinterpret_cast<char*>(&(a)))[6]=t;\
+  t = (reinterpret_cast<char*>(&(a)))[2];\
+  (reinterpret_cast<char*>(&(a)))[2]=(reinterpret_cast<char*>(&(a)))[5];\
+  (reinterpret_cast<char*>(&(a)))[5]=t;\
+  t = (reinterpret_cast<char*>(&(a)))[3];\
+  (reinterpret_cast<char*>(&(a)))[3]=(reinterpret_cast<char*>(&(a)))[4];\
+  (reinterpret_cast<char*>(&(a)))[4]=t;} while (0)
+#define KALDI_SWAP4(a) do { \
+  int t = (reinterpret_cast<char*>(&(a)))[0];\
+  (reinterpret_cast<char*>(&(a)))[0]=(reinterpret_cast<char*>(&(a)))[3];\
+  (reinterpret_cast<char*>(&(a)))[3]=t;\
+  t = (reinterpret_cast<char*>(&(a)))[1];\
+  (reinterpret_cast<char*>(&(a)))[1]=(reinterpret_cast<char*>(&(a)))[2];\
+  (reinterpret_cast<char*>(&(a)))[2]=t;} while (0)
+#define KALDI_SWAP2(a) do { \
+  int t = (reinterpret_cast<char*>(&(a)))[0];\
+  (reinterpret_cast<char*>(&(a)))[0]=(reinterpret_cast<char*>(&(a)))[1];\
+  (reinterpret_cast<char*>(&(a)))[1]=t;} while (0)
+
+
+// Declares (without defining) the copy constructor and operator=; use it in a class's private: section to forbid copying.
+#define KALDI_DISALLOW_COPY_AND_ASSIGN(type) \
+  type(const type&); \
+  void operator = (const type&)
+
+template<bool B> class KaldiCompileTimeAssert { };  // primary template: has no Check(), so a false condition fails to compile
+template<> class KaldiCompileTimeAssert<true> {  // only the 'true' specialization provides Check()
+ public:
+  static inline void Check() { }
+};
+
+#define KALDI_COMPILE_TIME_ASSERT(b) KaldiCompileTimeAssert<(b)>::Check()
+
+#define KALDI_ASSERT_IS_INTEGER_TYPE(I) \
+  KaldiCompileTimeAssert<std::numeric_limits<I>::is_specialized \
+                 && std::numeric_limits<I>::is_integer>::Check()
+
+#define KALDI_ASSERT_IS_FLOATING_TYPE(F) \
+  KaldiCompileTimeAssert<std::numeric_limits<F>::is_specialized \
+                && !std::numeric_limits<F>::is_integer>::Check()
+
+#if defined(_MSC_VER)
+#define KALDI_STRCASECMP _stricmp
+#elif defined(__CYGWIN__)
+#include <strings.h>
+#define KALDI_STRCASECMP strcasecmp
+#else
+#define KALDI_STRCASECMP strcasecmp
+#endif
+#ifdef _MSC_VER
+# define KALDI_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10)  /* base-10 parse; no trailing ';' so it is usable in expressions */
+#else
+# define KALDI_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10)  /* base-10 parse; no trailing ';' so it is usable in expressions */
+#endif
+
+#endif // KALDI_BASE_KALDI_UTILS_H_
diff --git a/torchaudio/csrc/kaldi/base/timer.h b/torchaudio/csrc/kaldi/base/timer.h
new file mode 100644
index 00000000000..0e033766362
--- /dev/null
+++ b/torchaudio/csrc/kaldi/base/timer.h
@@ -0,0 +1,115 @@
+// base/timer.h
+
+// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+
+// http://www.apache.org/licenses/LICENSE-2.0
+
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+#ifndef KALDI_BASE_TIMER_H_
+#define KALDI_BASE_TIMER_H_
+
+#include "base/kaldi-utils.h"
+#include "base/kaldi-error.h"
+
+
+#if defined(_MSC_VER) || defined(MINGW)
+
+namespace kaldi {
+class Timer {  // Windows variant, built on QueryPerformanceCounter.
+ public:
+  Timer() { Reset(); }
+
+  // You can initialize with bool to control whether or not you want the time to
+  // be set when the object is created (if false, call Reset() before Elapsed()).
+  explicit Timer(bool set_timer) { if (set_timer) Reset(); }
+
+  void Reset() {
+    QueryPerformanceCounter(&time_start_);
+  }
+  double Elapsed() const {  // seconds since the last Reset(); 0.0 if no frequency is available
+    LARGE_INTEGER time_end;
+    LARGE_INTEGER freq;
+    QueryPerformanceCounter(&time_end);
+
+    if (QueryPerformanceFrequency(&freq) == 0) {
+      // Hardware does not support this.
+      return 0.0;
+    }
+    return (static_cast<double>(time_end.QuadPart) -
+            static_cast<double>(time_start_.QuadPart)) /
+           (static_cast<double>(freq.QuadPart));
+  }
+ private:
+  LARGE_INTEGER time_start_;
+};
+
+
+#else
+#include <sys/time.h>
+#include <unistd.h>
+
+namespace kaldi {
+class Timer {  // POSIX variant, built on gettimeofday().
+ public:
+  Timer() { Reset(); }
+
+  // You can initialize with bool to control whether or not you want the time to
+  // be set when the object is created (if false, call Reset() before Elapsed()).
+  explicit Timer(bool set_timer) { if (set_timer) Reset(); }
+
+  void Reset() { gettimeofday(&this->time_start_, NULL); }
+
+  /// Returns time in seconds.
+  double Elapsed() const {
+    struct timeval time_end;
+    // The timezone argument is obsolete; POSIX leaves a non-NULL value unspecified.
+    gettimeofday(&time_end, NULL);
+    double t1, t2;
+    t1 = static_cast<double>(time_start_.tv_sec) +
+         static_cast<double>(time_start_.tv_usec)/(1000*1000);
+    t2 = static_cast<double>(time_end.tv_sec) +
+         static_cast<double>(time_end.tv_usec)/(1000*1000);
+    return t2-t1;
+  }
+
+ private:
+  struct timeval time_start_;
+  struct timezone time_zone_;  // no longer written; kept so the class layout is unchanged
+};
+
+#endif
+
+class Profiler {
+ public:
+  // Caution: the 'const char' should always be a string constant; for speed,
+  // internally the profiling code uses the address of it as a lookup key.
+  Profiler(const char *function_name): name_(function_name) { }
+  ~Profiler();
+ private:
+  Timer tim_;  // default-constructed, so it starts timing when the Profiler is created
+  const char *name_;
+};
+
+// To add timing info for a function, you just put
+// KALDI_PROFILE;
+// at the beginning of the function. Caution: this doesn't
+// include the class name.
+#define KALDI_PROFILE Profiler _profiler(__func__)
+
+
+
+} // namespace kaldi
+
+
+#endif // KALDI_BASE_TIMER_H_
diff --git a/torchaudio/csrc/kaldi/base/version.h b/torchaudio/csrc/kaldi/base/version.h
new file mode 100644
index 00000000000..451b972260c
--- /dev/null
+++ b/torchaudio/csrc/kaldi/base/version.h
@@ -0,0 +1,4 @@
+// This file was automatically created by ./get_version.sh.
+// It is only included by ./kaldi-error.cc.
+#define KALDI_VERSION "5.5.839~3-0c6a"
+#define KALDI_GIT_HEAD "0c6a3dcf0ca2cbd2b7a180183ca7665465d5d042"
diff --git a/torchaudio/csrc/kaldi/feat/feature-common-inl.h b/torchaudio/csrc/kaldi/feat/feature-common-inl.h
new file mode 100644
index 00000000000..26127a4dc4d
--- /dev/null
+++ b/torchaudio/csrc/kaldi/feat/feature-common-inl.h
@@ -0,0 +1,99 @@
+// feat/feature-common-inl.h
+
+// Copyright 2016 Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_FEAT_FEATURE_COMMON_INL_H_
+#define KALDI_FEAT_FEATURE_COMMON_INL_H_
+
+#include "feat/resample.h"
+// Do not include this file directly. It is included by feat/feature-common.h
+
+namespace kaldi {
+
+template <class F>
+void OfflineFeatureTpl<F>::ComputeFeatures(
+    const VectorBase<BaseFloat> &wave,
+    BaseFloat sample_freq,
+    BaseFloat vtln_warp,
+    Matrix<BaseFloat> *output) {
+  KALDI_ASSERT(output != NULL);
+  BaseFloat new_sample_freq = computer_.GetFrameOptions().samp_freq;
+  if (sample_freq == new_sample_freq) {  // rates already match: no resampling needed
+    Compute(wave, vtln_warp, output);
+  } else {
+    if (new_sample_freq < sample_freq &&
+        ! computer_.GetFrameOptions().allow_downsample)
+      KALDI_ERR << "Waveform and config sample Frequency mismatch: "
+                << sample_freq << " .vs " << new_sample_freq
+                << " (use --allow-downsample=true to allow "
+                << " downsampling the waveform).";
+    else if (new_sample_freq > sample_freq &&
+             ! computer_.GetFrameOptions().allow_upsample)
+      KALDI_ERR << "Waveform and config sample Frequency mismatch: "
+                << sample_freq << " .vs " << new_sample_freq
+                << " (use --allow-upsample=true option to allow "
+                << " upsampling the waveform).";
+    // Resample the waveform to the configured rate, then compute as usual.
+    Vector<BaseFloat> resampled_wave(wave);
+    ResampleWaveform(sample_freq, wave,
+                     new_sample_freq, &resampled_wave);
+    Compute(resampled_wave, vtln_warp, output);
+  }
+}
+
+template <class F>
+void OfflineFeatureTpl<F>::Compute(
+    const VectorBase<BaseFloat> &wave,
+    BaseFloat vtln_warp,
+    Matrix<BaseFloat> *output) {
+  KALDI_ASSERT(output != NULL);
+  int32 rows_out = NumFrames(wave.Dim(), computer_.GetFrameOptions()),
+      cols_out = computer_.Dim();
+  if (rows_out == 0) {  // no frames: return an empty matrix
+    output->Resize(0, 0);
+    return;
+  }
+  output->Resize(rows_out, cols_out);
+  Vector<BaseFloat> window; // windowed waveform.
+  bool use_raw_log_energy = computer_.NeedRawLogEnergy();
+  for (int32 r = 0; r < rows_out; r++) { // r is frame index.
+    BaseFloat raw_log_energy = 0.0;
+    ExtractWindow(0, wave, r, computer_.GetFrameOptions(),
+                  feature_window_function_, &window,
+                  (use_raw_log_energy ? &raw_log_energy : NULL));
+
+    SubVector<BaseFloat> output_row(*output, r);  // view of row r, filled in by the computer
+    computer_.Compute(raw_log_energy, vtln_warp, &window, &output_row);
+  }
+}
+
+template <class F>
+void OfflineFeatureTpl<F>::Compute(
+    const VectorBase<BaseFloat> &wave,
+    BaseFloat vtln_warp,
+    Matrix<BaseFloat> *output) const {
+  OfflineFeatureTpl<F> temp(*this);
+  // call the non-const version of Compute() on a temporary copy of this object.
+  // This is a workaround for const-ness that may sometimes be useful in
+  // multi-threaded code, although it's not optimally efficient.
+  temp.Compute(wave, vtln_warp, output);
+}
+
+} // end namespace kaldi
+
+#endif
diff --git a/torchaudio/csrc/kaldi/feat/feature-common.h b/torchaudio/csrc/kaldi/feat/feature-common.h
new file mode 100644
index 00000000000..3c2fbd37381
--- /dev/null
+++ b/torchaudio/csrc/kaldi/feat/feature-common.h
@@ -0,0 +1,176 @@
+// feat/feature-common.h
+
+// Copyright 2016 Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABILITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_FEAT_FEATURE_COMMON_H_
+#define KALDI_FEAT_FEATURE_COMMON_H_
+
+#include <map>
+#include <string>
+#include "feat/feature-window.h"
+
+namespace kaldi {
+/// @addtogroup feat FeatureExtraction
+/// @{
+
+
+
+/// This class is only added for documentation, it is not intended to ever be
+/// used.
+struct ExampleFeatureComputerOptions {
+  FrameExtractionOptions frame_opts;
+  // .. more would go here.
+};
+
+/// This class is only added for documentation, it is not intended to ever be
+/// used. It documents the interface of the *Computer classes which wrap the
+/// low-level feature extraction. The template argument F of OfflineFeatureTpl must
+/// follow this interface. This interface is intended for features such as
+/// MFCCs and PLPs which can be computed frame by frame.
+class ExampleFeatureComputer {
+ public:
+  typedef ExampleFeatureComputerOptions Options;
+
+  /// Returns a reference to the frame-extraction options class, which
+  /// will be part of our own options class.
+  const FrameExtractionOptions &GetFrameOptions() const {
+    return opts_.frame_opts;
+  }
+
+  /// Returns the feature dimension
+  int32 Dim() const;
+
+  /// Returns true if this function may inspect the raw log-energy of the signal
+  /// (before windowing and pre-emphasis); it's safe to always return true, but
+  /// setting it to false enables an optimization.
+  bool NeedRawLogEnergy() const { return true; }
+
+  /// constructor from options class; it should not store a reference or pointer
+  /// to the options class but should copy it.
+  explicit ExampleFeatureComputer(const ExampleFeatureComputerOptions &opts):
+      opts_(opts) { }
+
+  /// Copy constructor; all of these classes must have one.
+  ExampleFeatureComputer(const ExampleFeatureComputer &other);
+
+  /**
+     Function that computes one frame of features from
+     one frame of signal.
+
+     @param [in] signal_raw_log_energy The log-energy of the frame of the signal
+         prior to windowing and pre-emphasis, or
+         log(numeric_limits<float>::min()), whichever is greater. Must be
+         ignored by this function if this class returns false from
+         this->NeedRawLogEnergy().
+     @param [in] vtln_warp The VTLN warping factor that the user wants
+         to be applied when computing features for this utterance. Will
+         normally be 1.0, meaning no warping is to be done. The value will
+         be ignored for feature types that don't support VTLN, such as
+         spectrogram features.
+     @param [in] signal_frame One frame of the signal,
+       as extracted using the function ExtractWindow() using the options
+       returned by this->GetFrameOptions(). The function will use the
+       vector as a workspace, which is why it's a non-const pointer.
+     @param [out] feature Pointer to a vector of size this->Dim(), to which
+         the computed feature will be written.
+  */
+  void Compute(BaseFloat signal_raw_log_energy,
+               BaseFloat vtln_warp,
+               VectorBase<BaseFloat> *signal_frame,
+               VectorBase<BaseFloat> *feature);
+
+ private:
+  // disallow assignment.
+  ExampleFeatureComputer &operator = (const ExampleFeatureComputer &in);
+  Options opts_;
+};
+
+
+/// This templated class is intended for offline feature extraction, i.e. where
+/// you have access to the entire signal at the start. It exists mainly to be
+/// drop-in replacement for the old (pre-2016) classes Mfcc, Plp and so on, for
+/// use in the offline case. In April 2016 we reorganized the online
+/// feature-computation code for greater modularity and to have correct support
+/// for the snip-edges=false option.
+template <class F>
+class OfflineFeatureTpl {
+ public:
+  typedef typename F::Options Options;
+
+  // Note: feature_window_function_ is the windowing function, which is initialized
+  // using the options class, that we cache at this level.
+  OfflineFeatureTpl(const Options &opts):
+      computer_(opts),
+      feature_window_function_(computer_.GetFrameOptions()) { }
+
+  // Internal (and back-compatibility) interface for computing features, which
+  // requires that the user has already checked that the sampling frequency
+  // of the waveform is equal to the sampling frequency specified in
+  // the frame-extraction options.
+  void Compute(const VectorBase<BaseFloat> &wave,
+               BaseFloat vtln_warp,
+               Matrix<BaseFloat> *output);
+
+  // This const version of Compute() is a wrapper that
+  // calls the non-const version on a temporary object.
+  // It's less efficient than the non-const version.
+  void Compute(const VectorBase<BaseFloat> &wave,
+               BaseFloat vtln_warp,
+               Matrix<BaseFloat> *output) const;
+
+  /**
+     Computes the features for one file (one sequence of features).
+     This is the newer interface where you specify the sample frequency
+     of the input waveform.
+       @param [in] wave The input waveform
+       @param [in] sample_freq The sampling frequency with which
+                   'wave' was sampled.
+                   If it differs from the frequency in the config, the waveform
+                   is resampled when allow_downsample / allow_upsample (as
+                   applicable) is set in the frame options; else it's an error.
+     @param [in] vtln_warp The VTLN warping factor (will normally
+                 be 1.0)
+     @param [out] output The matrix of features, where the row-index
+                  is the frame index.
+  */
+  void ComputeFeatures(const VectorBase<BaseFloat> &wave,
+                       BaseFloat sample_freq,
+                       BaseFloat vtln_warp,
+                       Matrix<BaseFloat> *output);
+
+  int32 Dim() const { return computer_.Dim(); }
+
+  // Copy constructor.
+  OfflineFeatureTpl(const OfflineFeatureTpl<F> &other):
+      computer_(other.computer_),
+      feature_window_function_(other.feature_window_function_) { }
+ private:
+  // Disallow assignment.
+  OfflineFeatureTpl<F> &operator =(const OfflineFeatureTpl<F> &other);
+
+  F computer_;
+  FeatureWindowFunction feature_window_function_;
+};
+
+/// @} End of "addtogroup feat"
+} // namespace kaldi
+
+
+#include "feat/feature-common-inl.h"
+
+#endif // KALDI_FEAT_FEATURE_COMMON_H_
diff --git a/torchaudio/csrc/kaldi/feat/feature-fbank.h b/torchaudio/csrc/kaldi/feat/feature-fbank.h
new file mode 100644
index 00000000000..f57d185a41c
--- /dev/null
+++ b/torchaudio/csrc/kaldi/feat/feature-fbank.h
@@ -0,0 +1,149 @@
+// feat/feature-fbank.h
+
+// Copyright 2009-2012 Karel Vesely
+// 2016 Johns Hopkins University (author: Daniel Povey)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
+// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
+// MERCHANTABLITY OR NON-INFRINGEMENT.
+// See the Apache 2 License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef KALDI_FEAT_FEATURE_FBANK_H_
+#define KALDI_FEAT_FEATURE_FBANK_H_
+
+#include <map>
+#include <string>
+
+#include "feat/feature-common.h"
+#include "feat/feature-functions.h"
+#include "feat/feature-window.h"
+#include "feat/mel-computations.h"
+
+namespace kaldi {
+/// @addtogroup feat FeatureExtraction
+/// @{
+
+
+/// FbankOptions contains basic options for computing filterbank features.
+/// It only includes things that can be done in a "stateless" way, i.e.
+/// it does not include energy max-normalization.
+/// It does not include delta computation.
+struct FbankOptions {
+  FrameExtractionOptions frame_opts;
+  MelBanksOptions mel_opts;
+  bool use_energy; // append an extra dimension with energy to the filter banks
+  BaseFloat energy_floor;
+  bool raw_energy; // If true, compute energy before preemphasis and windowing
+  bool htk_compat; // If true, put energy last (if using energy)
+  bool use_log_fbank; // if true (default), produce log-filterbank, else linear
+  bool use_power; // if true (default), use power in filterbank analysis, else magnitude.
+
+  FbankOptions(): mel_opts(23),
+                 // defaults the #mel-banks to 23 for the FBANK computations.
+                 // this seems to be common for 16khz-sampled data,
+                 // but for 8khz-sampled data, 15 may be better.
+                 use_energy(false),
+                 energy_floor(0.0),
+                 raw_energy(true),
+                 htk_compat(false),
+                 use_log_fbank(true),
+                 use_power(true) {}
+
+  void Register(OptionsItf *opts) {  // registers every option above, plus the frame and mel sub-options
+    frame_opts.Register(opts);
+    mel_opts.Register(opts);
+    opts->Register("use-energy", &use_energy,
+                   "Add an extra dimension with energy to the FBANK output.");
+    opts->Register("energy-floor", &energy_floor,
+                   "Floor on energy (absolute, not relative) in FBANK computation. "
+                   "Only makes a difference if --use-energy=true; only necessary if "
+                   "--dither=0.0. Suggested values: 0.1 or 1.0");
+    opts->Register("raw-energy", &raw_energy,
+                   "If true, compute energy before preemphasis and windowing");
+    opts->Register("htk-compat", &htk_compat, "If true, put energy last. "
+                   "Warning: not sufficient to get HTK compatible features (need "
+                   "to change other parameters).");
+    opts->Register("use-log-fbank", &use_log_fbank,
+                   "If true, produce log-filterbank, else produce linear.");
+    opts->Register("use-power", &use_power,
+                   "If true, use power, else use magnitude.");
+  }
+};
+
+
+/// Class for computing mel-filterbank features; see \ref feat_mfcc for more
+/// information.
+class FbankComputer {
+ public:
+  typedef FbankOptions Options;
+
+  explicit FbankComputer(const FbankOptions &opts);
+  FbankComputer(const FbankComputer &other);
+
+  int32 Dim() const {  // number of mel bins, plus one if the energy is appended
+    return opts_.mel_opts.num_bins + (opts_.use_energy ? 1 : 0);
+  }
+
+  bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; }
+
+  const FrameExtractionOptions &GetFrameOptions() const {
+    return opts_.frame_opts;
+  }
+
+  /**
+     Function that computes one frame of features from
+     one frame of signal.
+
+     @param [in] signal_raw_log_energy The log-energy of the frame of the signal
+         prior to windowing and pre-emphasis, or
+         log(numeric_limits<float>::min()), whichever is greater. Must be
+         ignored by this function if this class returns false from
+         this->NeedRawLogEnergy().
+     @param [in] vtln_warp The VTLN warping factor that the user wants
+         to be applied when computing features for this utterance. Will
+         normally be 1.0, meaning no warping is to be done. The value will
+         be ignored for feature types that don't support VTLN, such as
+         spectrogram features.
+     @param [in] signal_frame One frame of the signal,
+       as extracted using the function ExtractWindow() using the options
+       returned by this->GetFrameOptions(). The function will use the
+       vector as a workspace, which is why it's a non-const pointer.
+     @param [out] feature Pointer to a vector of size this->Dim(), to which
+         the computed feature will be written.
+  */
+  void Compute(BaseFloat signal_raw_log_energy,
+               BaseFloat vtln_warp,
+               VectorBase<BaseFloat> *signal_frame,
+               VectorBase<BaseFloat> *feature);
+
+  ~FbankComputer();
+
+ private:
+  const MelBanks *GetMelBanks(BaseFloat vtln_warp);
+
+
+  FbankOptions opts_;
+  BaseFloat log_energy_floor_;
+  std::map<BaseFloat, MelBanks*> mel_banks_; // BaseFloat is VTLN coefficient.
+  SplitRadixRealFft<BaseFloat> *srfft_;
+  // Disallow assignment.
+  FbankComputer &operator =(const FbankComputer &other);
+};
+
+typedef OfflineFeatureTpl<FbankComputer> Fbank;
+
+/// @} End of "addtogroup feat"
+} // namespace kaldi
+
+
+#endif // KALDI_FEAT_FEATURE_FBANK_H_
diff --git a/torchaudio/csrc/kaldi/feat/feature-functions.cc b/torchaudio/csrc/kaldi/feat/feature-functions.cc
new file mode 100644
index 00000000000..76500ccf87a
--- /dev/null
+++ b/torchaudio/csrc/kaldi/feat/feature-functions.cc
@@ -0,0 +1,362 @@
+// feat/feature-functions.cc
+
+// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation
+// 2013 Johns Hopkins University (author: Daniel Povey)
+// 2014 IMSL, PKU-HKUST (author: Wei Shi)
+
+// See ../../COPYING for clarification regarding multiple authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#include "feat/feature-functions.h" +#include "matrix/matrix-functions.h" + + +namespace kaldi { + +void ComputePowerSpectrum(VectorBase *waveform) { + int32 dim = waveform->Dim(); + + // no, letting it be non-power-of-two for now. + // KALDI_ASSERT(dim > 0 && (dim & (dim-1) == 0)); // make sure a power of two.. actually my FFT code + // does not require this (dan) but this is better in case we use different code [dan]. + + // RealFft(waveform, true); // true == forward (not inverse) FFT; makes no difference here, + // as we just want power spectrum. + + // now we have in waveform, first half of complex spectrum + // it's stored as [real0, realN/2, real1, im1, real2, im2, ...] + int32 half_dim = dim/2; + BaseFloat first_energy = (*waveform)(0) * (*waveform)(0), + last_energy = (*waveform)(1) * (*waveform)(1); // handle this special case + for (int32 i = 1; i < half_dim; i++) { + BaseFloat real = (*waveform)(i*2), im = (*waveform)(i*2 + 1); + (*waveform)(i) = real*real + im*im; + } + (*waveform)(0) = first_energy; + (*waveform)(half_dim) = last_energy; // Will actually never be used, and anyway + // if the signal has been bandlimited sensibly this should be zero. +} + + +DeltaFeatures::DeltaFeatures(const DeltaFeaturesOptions &opts): opts_(opts) { + KALDI_ASSERT(opts.order >= 0 && opts.order < 1000); // just make sure we don't get binary junk. + // opts will normally be 2 or 3. + KALDI_ASSERT(opts.window > 0 && opts.window < 1000); // again, basic sanity check. 
+ // normally the window size will be two. + + scales_.resize(opts.order+1); + scales_[0].Resize(1); + scales_[0](0) = 1.0; // trivial window for 0th order delta [i.e. baseline feats] + + for (int32 i = 1; i <= opts.order; i++) { + Vector &prev_scales = scales_[i-1], + &cur_scales = scales_[i]; + int32 window = opts.window; // this code is designed to still + // work if instead we later make it an array and do opts.window[i-1], + // or something like that. "window" is a parameter specifying delta-window + // width which is actually 2*window + 1. + KALDI_ASSERT(window != 0); + int32 prev_offset = (static_cast(prev_scales.Dim()-1))/2, + cur_offset = prev_offset + window; + cur_scales.Resize(prev_scales.Dim() + 2*window); // also zeros it. + + BaseFloat normalizer = 0.0; + for (int32 j = -window; j <= window; j++) { + normalizer += j*j; + for (int32 k = -prev_offset; k <= prev_offset; k++) { + cur_scales(j+k+cur_offset) += + static_cast(j) * prev_scales(k+prev_offset); + } + } + cur_scales.Scale(1.0 / normalizer); + } +} + +void DeltaFeatures::Process(const MatrixBase &input_feats, + int32 frame, + VectorBase *output_frame) const { + KALDI_ASSERT(frame < input_feats.NumRows()); + int32 num_frames = input_feats.NumRows(), + feat_dim = input_feats.NumCols(); + KALDI_ASSERT(static_cast(output_frame->Dim()) == feat_dim * (opts_.order+1)); + output_frame->SetZero(); + for (int32 i = 0; i <= opts_.order; i++) { + const Vector &scales = scales_[i]; + int32 max_offset = (scales.Dim() - 1) / 2; + SubVector output(*output_frame, i*feat_dim, feat_dim); + for (int32 j = -max_offset; j <= max_offset; j++) { + // if asked to read + int32 offset_frame = frame + j; + if (offset_frame < 0) offset_frame = 0; + else if (offset_frame >= num_frames) + offset_frame = num_frames - 1; + BaseFloat scale = scales(j + max_offset); + if (scale != 0.0) + output.AddVec(scale, input_feats.Row(offset_frame)); + } + } +} + +ShiftedDeltaFeatures::ShiftedDeltaFeatures( + const 
ShiftedDeltaFeaturesOptions &opts): opts_(opts) { + KALDI_ASSERT(opts.window > 0 && opts.window < 1000); + + // Default window is 1. + int32 window = opts.window; + KALDI_ASSERT(window != 0); + scales_.Resize(1 + 2*window); // also zeros it. + BaseFloat normalizer = 0.0; + for (int32 j = -window; j <= window; j++) { + normalizer += j*j; + scales_(j + window) += static_cast(j); + } + scales_.Scale(1.0 / normalizer); +} + +void ShiftedDeltaFeatures::Process(const MatrixBase &input_feats, + int32 frame, + SubVector *output_frame) const { + KALDI_ASSERT(frame < input_feats.NumRows()); + int32 num_frames = input_feats.NumRows(), + feat_dim = input_feats.NumCols(); + KALDI_ASSERT(static_cast(output_frame->Dim()) + == feat_dim * (opts_.num_blocks + 1)); + output_frame->SetZero(); + + // The original features + SubVector output(*output_frame, 0, feat_dim); + output.AddVec(1.0, input_feats.Row(frame)); + + // Concatenate the delta-blocks. Each block is block_shift + // (usually 3) frames apart. + for (int32 i = 0; i < opts_.num_blocks; i++) { + int32 max_offset = (scales_.Dim() - 1) / 2; + SubVector output(*output_frame, (i + 1) * feat_dim, feat_dim); + for (int32 j = -max_offset; j <= max_offset; j++) { + int32 offset_frame = frame + j + i * opts_.block_shift; + if (offset_frame < 0) offset_frame = 0; + else if (offset_frame >= num_frames) + offset_frame = num_frames - 1; + BaseFloat scale = scales_(j + max_offset); + if (scale != 0.0) + output.AddVec(scale, input_feats.Row(offset_frame)); + } + } +} + +void ComputeDeltas(const DeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features) { + output_features->Resize(input_features.NumRows(), + input_features.NumCols() + *(delta_opts.order + 1)); + DeltaFeatures delta(delta_opts); + for (int32 r = 0; r < static_cast(input_features.NumRows()); r++) { + SubVector row(*output_features, r); + delta.Process(input_features, r, &row); + } +} + +void ComputeShiftedDeltas(const 
ShiftedDeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features) { + output_features->Resize(input_features.NumRows(), + input_features.NumCols() + * (delta_opts.num_blocks + 1)); + ShiftedDeltaFeatures delta(delta_opts); + + for (int32 r = 0; r < static_cast(input_features.NumRows()); r++) { + SubVector row(*output_features, r); + delta.Process(input_features, r, &row); + } +} + + +void InitIdftBases(int32 n_bases, int32 dimension, Matrix *mat_out) { + BaseFloat angle = M_PI / static_cast(dimension - 1); + BaseFloat scale = 1.0f / (2.0 * static_cast(dimension - 1)); + mat_out->Resize(n_bases, dimension); + for (int32 i = 0; i < n_bases; i++) { + (*mat_out)(i, 0) = 1.0 * scale; + BaseFloat i_fl = static_cast(i); + for (int32 j = 1; j < dimension - 1; j++) { + BaseFloat j_fl = static_cast(j); + (*mat_out)(i, j) = 2.0 * scale * cos(angle * i_fl * j_fl); + } + + (*mat_out)(i, dimension -1) + = scale * cos(angle * i_fl * static_cast(dimension-1)); + } +} + +void SpliceFrames(const MatrixBase &input_features, + int32 left_context, + int32 right_context, + Matrix *output_features) { + int32 T = input_features.NumRows(), D = input_features.NumCols(); + if (T == 0 || D == 0) + KALDI_ERR << "SpliceFrames: empty input"; + KALDI_ASSERT(left_context >= 0 && right_context >= 0); + int32 N = 1 + left_context + right_context; + output_features->Resize(T, D*N); + for (int32 t = 0; t < T; t++) { + SubVector dst_row(*output_features, t); + for (int32 j = 0; j < N; j++) { + int32 t2 = t + j - left_context; + if (t2 < 0) t2 = 0; + if (t2 >= T) t2 = T-1; + SubVector dst(dst_row, j*D, D), + src(input_features, t2); + dst.CopyFromVec(src); + } + } +} + +void ReverseFrames(const MatrixBase &input_features, + Matrix *output_features) { + int32 T = input_features.NumRows(), D = input_features.NumCols(); + if (T == 0 || D == 0) + KALDI_ERR << "ReverseFrames: empty input"; + output_features->Resize(T, D); + for (int32 t = 0; t < T; t++) { + SubVector 
dst_row(*output_features, t); + SubVector src_row(input_features, T-1-t); + dst_row.CopyFromVec(src_row); + } +} + + +void SlidingWindowCmnOptions::Check() const { + KALDI_ASSERT(cmn_window > 0); + if (center) + KALDI_ASSERT(min_window > 0 && min_window <= cmn_window); + // else ignored so value doesn't matter. +} + +// Internal version of SlidingWindowCmn with double-precision arguments. +void SlidingWindowCmnInternal(const SlidingWindowCmnOptions &opts, + const MatrixBase &input, + MatrixBase *output) { + opts.Check(); + int32 num_frames = input.NumRows(), dim = input.NumCols(), + last_window_start = -1, last_window_end = -1, + warning_count = 0; + Vector cur_sum(dim), cur_sumsq(dim); + + for (int32 t = 0; t < num_frames; t++) { + int32 window_start, window_end; // note: window_end will be one + // past the end of the window we use for normalization. + if (opts.center) { + window_start = t - (opts.cmn_window / 2); + window_end = window_start + opts.cmn_window; + } else { + window_start = t - opts.cmn_window; + window_end = t + 1; + } + if (window_start < 0) { // shift window right if starts <0. 
+ window_end -= window_start; + window_start = 0; // or: window_start -= window_start + } + if (!opts.center) { + if (window_end > t) + window_end = std::max(t + 1, opts.min_window); + } + if (window_end > num_frames) { + window_start -= (window_end - num_frames); + window_end = num_frames; + if (window_start < 0) window_start = 0; + } + if (last_window_start == -1) { + SubMatrix input_part(input, + window_start, window_end - window_start, + 0, dim); + cur_sum.AddRowSumMat(1.0, input_part , 0.0); + if (opts.normalize_variance) + cur_sumsq.AddDiagMat2(1.0, input_part, kTrans, 0.0); + } else { + if (window_start > last_window_start) { + KALDI_ASSERT(window_start == last_window_start + 1); + SubVector frame_to_remove(input, last_window_start); + cur_sum.AddVec(-1.0, frame_to_remove); + if (opts.normalize_variance) + cur_sumsq.AddVec2(-1.0, frame_to_remove); + } + if (window_end > last_window_end) { + KALDI_ASSERT(window_end == last_window_end + 1); + SubVector frame_to_add(input, last_window_end); + cur_sum.AddVec(1.0, frame_to_add); + if (opts.normalize_variance) + cur_sumsq.AddVec2(1.0, frame_to_add); + } + } + int32 window_frames = window_end - window_start; + last_window_start = window_start; + last_window_end = window_end; + + KALDI_ASSERT(window_frames > 0); + SubVector input_frame(input, t), + output_frame(*output, t); + output_frame.CopyFromVec(input_frame); + output_frame.AddVec(-1.0 / window_frames, cur_sum); + + if (opts.normalize_variance) { + if (window_frames == 1) { + output_frame.Set(0.0); + } else { + Vector variance(cur_sumsq); + variance.Scale(1.0 / window_frames); + variance.AddVec2(-1.0 / (window_frames * window_frames), cur_sum); + // now "variance" is the variance of the features in the window, + // around their own mean. 
+ int32 num_floored; + variance.ApplyFloor(1.0e-10, &num_floored); + if (num_floored > 0 && num_frames > 1) { + if (opts.max_warnings == warning_count) { + KALDI_WARN << "Suppressing the remaining variance flooring " + << "warnings. Run program with --max-warnings=-1 to " + << "see all warnings."; + } + // If opts.max_warnings is a negative number, we won't restrict the + // number of times that the warning is printed out. + else if (opts.max_warnings < 0 + || opts.max_warnings > warning_count) { + KALDI_WARN << "Flooring when normalizing variance, floored " + << num_floored << " elements; num-frames was " + << window_frames; + } + warning_count++; + } + variance.ApplyPow(-0.5); // get inverse standard deviation. + output_frame.MulElements(variance); + } + } + } +} + + +void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, + const MatrixBase &input, + MatrixBase *output) { + KALDI_ASSERT(SameDim(input, *output) && input.NumRows() > 0); + Matrix input_dbl(input), output_dbl(input.NumRows(), input.NumCols()); + // call double-precision version + SlidingWindowCmnInternal(opts, input_dbl, &output_dbl); + output->CopyFromMat(output_dbl); +} + + + +} // namespace kaldi diff --git a/torchaudio/csrc/kaldi/feat/feature-functions.h b/torchaudio/csrc/kaldi/feat/feature-functions.h new file mode 100644 index 00000000000..52454f3048b --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/feature-functions.h @@ -0,0 +1,204 @@ +// feat/feature-functions.h + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Microsoft Corporation +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_FEAT_FEATURE_FUNCTIONS_H_ +#define KALDI_FEAT_FEATURE_FUNCTIONS_H_ + +#include +#include + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + +// ComputePowerSpectrum converts a complex FFT (as produced by the FFT +// functions in matrix/matrix-functions.h), and converts it into +// a power spectrum. If the complex FFT is a vector of size n (representing +// half the complex FFT of a real signal of size n, as described there), +// this function computes in the first (n/2) + 1 elements of it, the +// energies of the fft bins from zero to the Nyquist frequency. Contents of the +// remaining (n/2) - 1 elements are undefined at output. +void ComputePowerSpectrum(VectorBase *complex_fft); + + +struct DeltaFeaturesOptions { + int32 order; + int32 window; // e.g. 2; controls window size (window size is 2*window + 1) + // the behavior at the edges is to replicate the first or last frame. + // this is not configurable. 
+ + DeltaFeaturesOptions(int32 order = 2, int32 window = 2): + order(order), window(window) { } + void Register(OptionsItf *opts) { + opts->Register("delta-order", &order, "Order of delta computation"); + opts->Register("delta-window", &window, + "Parameter controlling window for delta computation (actual window" + " size for each delta order is 1 + 2*delta-window-size)"); + } +}; + +class DeltaFeatures { + public: + // This class provides a low-level function to compute delta features. + // The function takes as input a matrix of features and a frame index + // that it should compute the deltas on. It puts its output in an object + // of type VectorBase, of size (original-feature-dimension) * (opts.order+1). + // This is not the most efficient way to do the computation, but it's + // state-free and thus easier to understand + + explicit DeltaFeatures(const DeltaFeaturesOptions &opts); + + void Process(const MatrixBase &input_feats, + int32 frame, + VectorBase *output_frame) const; + private: + DeltaFeaturesOptions opts_; + std::vector > scales_; // a scaling window for each + // of the orders, including zero: multiply the features for each + // dimension by this window. +}; + +struct ShiftedDeltaFeaturesOptions { + int32 window, // The time delay and advance + num_blocks, + block_shift; // Distance between consecutive blocks + + ShiftedDeltaFeaturesOptions(): + window(1), num_blocks(7), block_shift(3) { } + void Register(OptionsItf *opts) { + opts->Register("delta-window", &window, "Size of delta advance and delay."); + opts->Register("num-blocks", &num_blocks, "Number of delta blocks in advance" + " of each frame to be concatenated"); + opts->Register("block-shift", &block_shift, "Distance between each block"); + } +}; + +class ShiftedDeltaFeatures { + public: + // This class provides a low-level function to compute shifted + // delta cepstra (SDC). + // The function takes as input a matrix of features and a frame index + // that it should compute the deltas on.
It puts its output in an object + // of type SubVector, of size original-feature-dimension * (num_blocks + 1). + + explicit ShiftedDeltaFeatures(const ShiftedDeltaFeaturesOptions &opts); + + void Process(const MatrixBase &input_feats, + int32 frame, + SubVector *output_frame) const; + private: + ShiftedDeltaFeaturesOptions opts_; + Vector scales_; // delta scaling window, of size 1 + 2*window + +}; + +// ComputeDeltas is a convenience function that computes deltas on a feature +// file. If you want to deal with features coming in bit by bit you would have +// to use the DeltaFeatures class directly, and do the computation frame by +// frame. Later we will have to come up with a nice mechanism to do this for +// features coming in. +void ComputeDeltas(const DeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features); + +// ComputeShiftedDeltas computes deltas from a feature file by applying +// ShiftedDeltaFeatures over the frames. This function is provided for +// convenience, however, ShiftedDeltaFeatures can be used directly. +void ComputeShiftedDeltas(const ShiftedDeltaFeaturesOptions &delta_opts, + const MatrixBase &input_features, + Matrix *output_features); + +// SpliceFrames will normally be used together with LDA. +// It splices frames together to make a window. At the +// start and end of an utterance, it duplicates the first +// and last frames. +// Will throw if input features are empty. +// left_context and right_context must be nonnegative. +// these both represent a number of frames (e.g. 4, 4 is +// a good choice). +void SpliceFrames(const MatrixBase &input_features, + int32 left_context, + int32 right_context, + Matrix *output_features); + +// ReverseFrames reverses the frames in time (used for backwards decoding) +void ReverseFrames(const MatrixBase &input_features, + Matrix *output_features); + + +void InitIdftBases(int32 n_bases, int32 dimension, Matrix *mat_out); + + +// This is used for speaker-id.
Also see OnlineCmnOptions in ../online2/, which +// is online CMN with no latency, for online speech recognition. +struct SlidingWindowCmnOptions { + int32 cmn_window; + int32 min_window; + int32 max_warnings; + bool normalize_variance; + bool center; + + SlidingWindowCmnOptions(): + cmn_window(600), + min_window(100), + max_warnings(5), + normalize_variance(false), + center(false) { } + + void Register(OptionsItf *opts) { + opts->Register("cmn-window", &cmn_window, "Window in frames for running " + "average CMN computation"); + opts->Register("min-cmn-window", &min_window, "Minimum CMN window " + "used at start of decoding (adds latency only at start). " + "Only applicable if center == false, ignored if center==true"); + opts->Register("max-warnings", &max_warnings, "Maximum warnings to report " + "per utterance. 0 to disable, -1 to show all."); + opts->Register("norm-vars", &normalize_variance, "If true, normalize " + "variance to one."); // naming this as in apply-cmvn.cc + opts->Register("center", &center, "If true, use a window centered on the " + "current frame (to the extent possible, modulo end effects). " + "If false, window is to the left."); + } + void Check() const; +}; + + +/// Applies sliding-window cepstral mean and/or variance normalization. See the +/// strings registering the options in the options class for information on how +/// this works and what the options are. input and output must have the same +/// dimension.
+void SlidingWindowCmn(const SlidingWindowCmnOptions &opts, + const MatrixBase &input, + MatrixBase *output); + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + + +#endif // KALDI_FEAT_FEATURE_FUNCTIONS_H_ diff --git a/torchaudio/csrc/kaldi/feat/feature-mfcc.h b/torchaudio/csrc/kaldi/feat/feature-mfcc.h new file mode 100644 index 00000000000..dbfb9d60364 --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/feature-mfcc.h @@ -0,0 +1,154 @@ +// feat/feature-mfcc.h + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University +// 2014-2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_MFCC_H_ +#define KALDI_FEAT_FEATURE_MFCC_H_ + +#include +#include + +#include "feat/feature-common.h" +#include "feat/feature-functions.h" +#include "feat/feature-window.h" +#include "feat/mel-computations.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + +/// MfccOptions contains basic options for computing MFCC features. +struct MfccOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + int32 num_ceps; // e.g. 13: num cepstral coeffs, counting zero. + bool use_energy; // use energy; else C0 + BaseFloat energy_floor; // 0 by default; set to a value like 1.0 or 0.1 if + // you disable dithering. 
+ bool raw_energy; // If true, compute energy before preemphasis and windowing + BaseFloat cepstral_lifter; // Scaling factor on cepstra for HTK compatibility. + // if 0.0, no liftering is done. + bool htk_compat; // if true, put energy/C0 last and introduce a factor of + // sqrt(2) on C0 to be the same as HTK. + + MfccOptions() : mel_opts(23), + // defaults the #mel-banks to 23 for the MFCC computations. + // this seems to be common for 16khz-sampled data, + // but for 8khz-sampled data, 15 may be better. + num_ceps(13), + use_energy(true), + energy_floor(0.0), + raw_energy(true), + cepstral_lifter(22.0), + htk_compat(false) {} + + void Register(OptionsItf *opts) { + frame_opts.Register(opts); + mel_opts.Register(opts); + opts->Register("num-ceps", &num_ceps, + "Number of cepstra in MFCC computation (including C0)"); + opts->Register("use-energy", &use_energy, + "Use energy (not C0) in MFCC computation"); + opts->Register("energy-floor", &energy_floor, + "Floor on energy (absolute, not relative) in MFCC computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. Suggested values: 0.1 or 1.0"); + opts->Register("raw-energy", &raw_energy, + "If true, compute energy before preemphasis and windowing"); + opts->Register("cepstral-lifter", &cepstral_lifter, + "Constant that controls scaling of MFCCs"); + opts->Register("htk-compat", &htk_compat, + "If true, put energy or C0 last and use a factor of sqrt(2) on " + "C0. Warning: not sufficient to get HTK compatible features " + "(need to change other parameters)."); + } +}; + + + +// This is the new-style interface to the MFCC computation. 
+class MfccComputer { + public: + typedef MfccOptions Options; + explicit MfccComputer(const MfccOptions &opts); + MfccComputer(const MfccComputer &other); + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + int32 Dim() const { return opts_.num_ceps; } + + bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VTLN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + ~MfccComputer(); + private: + // disallow assignment. + MfccComputer &operator = (const MfccComputer &in); + + protected: + const MelBanks *GetMelBanks(BaseFloat vtln_warp); + + MfccOptions opts_; + Vector lifter_coeffs_; + Matrix dct_matrix_; // matrix we left-multiply by to perform DCT. + BaseFloat log_energy_floor_; + std::map mel_banks_; // BaseFloat is VTLN coefficient.
+ SplitRadixRealFft *srfft_; + + // note: mel_energies_ is specific to the frame we're processing, it's + // just a temporary workspace. + Vector mel_energies_; +}; + +typedef OfflineFeatureTpl Mfcc; + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_MFCC_H_ diff --git a/torchaudio/csrc/kaldi/feat/feature-plp.h b/torchaudio/csrc/kaldi/feat/feature-plp.h new file mode 100644 index 00000000000..4f156ca1e88 --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/feature-plp.h @@ -0,0 +1,176 @@ +// feat/feature-plp.h + +// Copyright 2009-2011 Petr Motlicek; Karel Vesely + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_FEATURE_PLP_H_ +#define KALDI_FEAT_FEATURE_PLP_H_ + +#include +#include + +#include "feat/feature-common.h" +#include "feat/feature-functions.h" +#include "feat/feature-window.h" +#include "feat/mel-computations.h" +#include "itf/options-itf.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + + + +/// PlpOptions contains basic options for computing PLP features. +/// It only includes things that can be done in a "stateless" way, i.e. +/// it does not include energy max-normalization. +/// It does not include delta computation. 
+struct PlpOptions { + FrameExtractionOptions frame_opts; + MelBanksOptions mel_opts; + int32 lpc_order; + int32 num_ceps; // num cepstra including zero + bool use_energy; // use energy; else C0 + BaseFloat energy_floor; + bool raw_energy; // If true, compute energy before preemphasis and windowing + BaseFloat compress_factor; + int32 cepstral_lifter; + BaseFloat cepstral_scale; + + bool htk_compat; // if true, put energy/C0 last and introduce a factor of + // sqrt(2) on C0 to be the same as HTK. + + PlpOptions() : mel_opts(23), + // default number of mel-banks for the PLP computation; this + // seems to be common for 16kHz-sampled data. For 8kHz-sampled + // data, 15 may be better. + lpc_order(12), + num_ceps(13), + use_energy(true), + energy_floor(0.0), + raw_energy(true), + compress_factor(0.33333), + cepstral_lifter(22), + cepstral_scale(1.0), + htk_compat(false) {} + + void Register(OptionsItf *opts) { + frame_opts.Register(opts); + mel_opts.Register(opts); + opts->Register("lpc-order", &lpc_order, + "Order of LPC analysis in PLP computation"); + opts->Register("num-ceps", &num_ceps, + "Number of cepstra in PLP computation (including C0)"); + opts->Register("use-energy", &use_energy, + "Use energy (not C0) for zeroth PLP feature"); + opts->Register("energy-floor", &energy_floor, + "Floor on energy (absolute, not relative) in PLP computation. " + "Only makes a difference if --use-energy=true; only necessary if " + "--dither=0.0. Suggested values: 0.1 or 1.0"); + opts->Register("raw-energy", &raw_energy, + "If true, compute energy before preemphasis and windowing"); + opts->Register("compress-factor", &compress_factor, + "Compression factor in PLP computation"); + opts->Register("cepstral-lifter", &cepstral_lifter, + "Constant that controls scaling of PLPs"); + opts->Register("cepstral-scale", &cepstral_scale, + "Scaling constant in PLP computation"); + opts->Register("htk-compat", &htk_compat, + "If true, put energy or C0 last. 
Warning: not sufficient " + "to get HTK compatible features (need to change other " + "parameters)."); + } +}; + + +/// This is the new-style interface to the PLP computation. +class PlpComputer { + public: + typedef PlpOptions Options; + explicit PlpComputer(const PlpOptions &opts); + PlpComputer(const PlpComputer &other); + + const FrameExtractionOptions &GetFrameOptions() const { + return opts_.frame_opts; + } + + int32 Dim() const { return opts_.num_ceps; } + + bool NeedRawLogEnergy() const { return opts_.use_energy && opts_.raw_energy; } + + /** + Function that computes one frame of features from + one frame of signal. + + @param [in] signal_raw_log_energy The log-energy of the frame of the signal + prior to windowing and pre-emphasis, or + log(numeric_limits::min()), whichever is greater. Must be + ignored by this function if this class returns false from + this->NeedRawLogEnergy(). + @param [in] vtln_warp The VTLN warping factor that the user wants + to be applied when computing features for this utterance. Will + normally be 1.0, meaning no warping is to be done. The value will + be ignored for feature types that don't support VTLN, such as + spectrogram features. + @param [in] signal_frame One frame of the signal, + as extracted using the function ExtractWindow() using the options + returned by this->GetFrameOptions(). The function will use the + vector as a workspace, which is why it's a non-const pointer. + @param [out] feature Pointer to a vector of size this->Dim(), to which + the computed feature will be written. + */ + void Compute(BaseFloat signal_raw_log_energy, + BaseFloat vtln_warp, + VectorBase *signal_frame, + VectorBase *feature); + + ~PlpComputer(); + private: + + const MelBanks *GetMelBanks(BaseFloat vtln_warp); + + const Vector *GetEqualLoudness(BaseFloat vtln_warp); + + PlpOptions opts_; + Vector lifter_coeffs_; + Matrix idft_bases_; + BaseFloat log_energy_floor_; + std::map mel_banks_; // BaseFloat is VTLN coefficient.
+ std::map* > equal_loudness_; + SplitRadixRealFft *srfft_; + + // temporary vector used inside Compute; size is opts_.mel_opts.num_bins + 2 + Vector mel_energies_duplicated_; + // temporary vector used inside Compute; size is opts_.lpc_order + 1 + Vector autocorr_coeffs_; + // temporary vector used inside Compute; size is opts_.lpc_order + Vector lpc_coeffs_; + // temporary vector used inside Compute; size is opts_.lpc_order + Vector raw_cepstrum_; + + // Disallow assignment. + PlpComputer &operator =(const PlpComputer &other); +}; + +typedef OfflineFeatureTpl Plp; + +/// @} End of "addtogroup feat" + +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_PLP_H_ diff --git a/torchaudio/csrc/kaldi/feat/feature-window.h b/torchaudio/csrc/kaldi/feat/feature-window.h new file mode 100644 index 00000000000..e6d673937ac --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/feature-window.h @@ -0,0 +1,223 @@ +// feat/feature-window.h + +// Copyright 2009-2011 Karel Vesely; Petr Motlicek; Saarland University +// 2014-2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef KALDI_FEAT_FEATURE_WINDOW_H_ +#define KALDI_FEAT_FEATURE_WINDOW_H_ + +#include +#include + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +struct FrameExtractionOptions { + BaseFloat samp_freq; + BaseFloat frame_shift_ms; // in milliseconds. + BaseFloat frame_length_ms; // in milliseconds. + BaseFloat dither; // Amount of dithering, 0.0 means no dither. + BaseFloat preemph_coeff; // Preemphasis coefficient. + bool remove_dc_offset; // Subtract mean of wave before FFT. + std::string window_type; // e.g. Hamming window + // May be "hamming", "rectangular", "povey", "hanning", "sine", "blackman" + // "povey" is a window I made to be similar to Hamming but to go to zero at the + // edges, it's pow((0.5 - 0.5*cos(n/N*2*pi)), 0.85) + // I just don't think the Hamming window makes sense as a windowing function. + bool round_to_power_of_two; + BaseFloat blackman_coeff; + bool snip_edges; + bool allow_downsample; + bool allow_upsample; + int max_feature_vectors; + FrameExtractionOptions(): + samp_freq(16000), + frame_shift_ms(10.0), + frame_length_ms(25.0), + dither(1.0), + preemph_coeff(0.97), + remove_dc_offset(true), + window_type("povey"), + round_to_power_of_two(true), + blackman_coeff(0.42), + snip_edges(true), + allow_downsample(false), + allow_upsample(false), + max_feature_vectors(-1) + { } + + void Register(OptionsItf *opts) { + opts->Register("sample-frequency", &samp_freq, + "Waveform data sample frequency (must match the waveform file, " + "if specified there)"); + opts->Register("frame-length", &frame_length_ms, "Frame length in milliseconds"); + opts->Register("frame-shift", &frame_shift_ms, "Frame shift in milliseconds"); + opts->Register("preemphasis-coefficient", &preemph_coeff, + "Coefficient for use in signal preemphasis"); + opts->Register("remove-dc-offset", &remove_dc_offset, + "Subtract mean from waveform on each frame"); 
+ opts->Register("dither", &dither, "Dithering constant (0.0 means no dither). " + "If you turn this off, you should set the --energy-floor " + "option, e.g. to 1.0 or 0.1"); + opts->Register("window-type", &window_type, "Type of window " + "(\"hamming\"|\"hanning\"|\"povey\"|\"rectangular\"" + "|\"sine\"|\"blackman\")"); + opts->Register("blackman-coeff", &blackman_coeff, + "Constant coefficient for generalized Blackman window."); + opts->Register("round-to-power-of-two", &round_to_power_of_two, + "If true, round window size to power of two by zero-padding " + "input to FFT."); + opts->Register("snip-edges", &snip_edges, + "If true, end effects will be handled by outputting only frames that " + "completely fit in the file, and the number of frames depends on the " + "frame-length. If false, the number of frames depends only on the " + "frame-shift, and we reflect the data at the ends."); + opts->Register("allow-downsample", &allow_downsample, + "If true, allow the input waveform to have a higher frequency than " + "the specified --sample-frequency (and we'll downsample)."); + opts->Register("max-feature-vectors", &max_feature_vectors, + "Memory optimization. If larger than 0, periodically remove feature " + "vectors so that only this number of the latest feature vectors is " + "retained."); + opts->Register("allow-upsample", &allow_upsample, + "If true, allow the input waveform to have a lower frequency than " + "the specified --sample-frequency (and we'll upsample)."); + } + int32 WindowShift() const { + return static_cast<int32>(samp_freq * 0.001 * frame_shift_ms); + } + int32 WindowSize() const { + return static_cast<int32>(samp_freq * 0.001 * frame_length_ms); + } + int32 PaddedWindowSize() const { + return (round_to_power_of_two ?
RoundUpToNearestPowerOfTwo(WindowSize()) : + WindowSize()); + } +}; + + +struct FeatureWindowFunction { + FeatureWindowFunction() {} + explicit FeatureWindowFunction(const FrameExtractionOptions &opts); + FeatureWindowFunction(const FeatureWindowFunction &other): + window(other.window) { } + Vector window; +}; + + +/** + This function returns the number of frames that we can extract from a wave + file with the given number of samples in it (assumed to have the same + sampling rate as specified in 'opts'). + + @param [in] num_samples The number of samples in the wave file. + @param [in] opts The frame-extraction options class + + @param [in] flush True if we are asserting that this number of samples is + 'all there is', false if we are expecting more data to possibly come + in. This only makes a difference to the answer if opts.snip_edges + == false. For offline feature extraction you always want flush == + true. In an online-decoding context, once you know (or decide) that + no more data is coming in, you'd call it with flush == true at the + end to flush out any remaining data. +*/ +int32 NumFrames(int64 num_samples, + const FrameExtractionOptions &opts, + bool flush = true); + +/* + This function returns the index of the first sample of the frame indexed + 'frame'. If snip-edges=true, it just returns frame * opts.WindowShift(); if + snip-edges=false, the formula is a little more complicated and the result may + be negative. +*/ +int64 FirstSampleOfFrame(int32 frame, + const FrameExtractionOptions &opts); + + + +void Dither(VectorBase *waveform, BaseFloat dither_value); + +void Preemphasize(VectorBase *waveform, BaseFloat preemph_coeff); + +/** + This function does all the windowing steps after actually + extracting the windowed signal: depending on the + configuration, it does dithering, dc offset removal, + preemphasis, and multiplication by the windowing function.
+ @param [in] opts The options class to be used + @param [in] window_function The windowing function-- should have + been initialized using 'opts'. + @param [in,out] window A vector of size opts.WindowSize(). Note: + it will typically be a sub-vector of a larger vector of size + opts.PaddedWindowSize(), with the remaining samples zero, + as the FFT code is more efficient if it operates on data with + power-of-two size. + @param [out] log_energy_pre_window If non-NULL, then after dithering and + DC offset removal, this function will write to this pointer the log of + the total energy (i.e. sum-squared) of the frame. + */ +void ProcessWindow(const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + VectorBase *window, + BaseFloat *log_energy_pre_window = NULL); + + +/* + ExtractWindow() extracts a windowed frame of waveform (possibly with a + power-of-two, padded size, depending on the config), including all the + processing done by ProcessWindow(). + + @param [in] sample_offset If 'wave' is not the entire waveform, but + part of it to the left has been discarded, then the + number of samples prior to 'wave' that we have + already discarded. Set this to zero if you are + processing the entire waveform in one piece, or + if you get 'no matching function' compilation + errors when updating the code. + @param [in] wave The waveform + @param [in] f The frame index to be extracted, with + 0 <= f < NumFrames(sample_offset + wave.Dim(), opts, true) + @param [in] opts The options class to be used + @param [in] window_function The windowing function, as derived from the + options class. + @param [out] window The windowed, possibly-padded waveform to be + extracted. Will be resized as needed. + @param [out] log_energy_pre_window If non-NULL, the log-energy of + the signal prior to pre-emphasis and multiplying by + the windowing function will be written to here.
+*/ +void ExtractWindow(int64 sample_offset, + const VectorBase &wave, + int32 f, + const FrameExtractionOptions &opts, + const FeatureWindowFunction &window_function, + Vector *window, + BaseFloat *log_energy_pre_window = NULL); + + +/// @} End of "addtogroup feat" +} // namespace kaldi + + +#endif // KALDI_FEAT_FEATURE_WINDOW_H_ diff --git a/torchaudio/csrc/kaldi/feat/mel-computations.h b/torchaudio/csrc/kaldi/feat/mel-computations.h new file mode 100644 index 00000000000..0c1d41ca45c --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/mel-computations.h @@ -0,0 +1,171 @@ +// feat/mel-computations.h + +// Copyright 2009-2011 Phonexia s.r.o.; Microsoft Corporation +// 2016 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_MEL_COMPUTATIONS_H_ +#define KALDI_FEAT_MEL_COMPUTATIONS_H_ + +#include +#include +#include +#include +#include +#include + +#include "base/kaldi-common.h" +#include "util/common-utils.h" +#include "matrix/matrix-lib.h" + + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +struct FrameExtractionOptions; // defined in feature-window.h + + +struct MelBanksOptions { + int32 num_bins; // e.g. 25; number of triangular bins + BaseFloat low_freq; // e.g. 
20; lower frequency cutoff + BaseFloat high_freq; // an upper frequency cutoff; 0 -> no cutoff, negative + // ->added to the Nyquist frequency to get the cutoff. + BaseFloat vtln_low; // vtln lower cutoff of warping function. + BaseFloat vtln_high; // vtln upper cutoff of warping function: if negative, added + // to the Nyquist frequency to get the cutoff. + bool debug_mel; + // htk_mode is a "hidden" config, it does not show up on command line. + // Enables more exact compatibility with HTK, for testing purposes. Affects + // mel-energy flooring and reproduces a bug in HTK. + bool htk_mode; + explicit MelBanksOptions(int num_bins = 25) + : num_bins(num_bins), low_freq(20), high_freq(0), vtln_low(100), + vtln_high(-500), debug_mel(false), htk_mode(false) {} + + void Register(OptionsItf *opts) { + opts->Register("num-mel-bins", &num_bins, + "Number of triangular mel-frequency bins"); + opts->Register("low-freq", &low_freq, + "Low cutoff frequency for mel bins"); + opts->Register("high-freq", &high_freq, + "High cutoff frequency for mel bins (if <= 0, offset from Nyquist)"); + opts->Register("vtln-low", &vtln_low, + "Low inflection point in piecewise linear VTLN warping function"); + opts->Register("vtln-high", &vtln_high, + "High inflection point in piecewise linear VTLN warping function" + " (if negative, offset from high-mel-freq"); + opts->Register("debug-mel", &debug_mel, + "Print out debugging information for mel bin computation"); + } +}; + + +class MelBanks { + public: + + static inline BaseFloat InverseMelScale(BaseFloat mel_freq) { + return 700.0f * (expf (mel_freq / 1127.0f) - 1.0f); + } + + static inline BaseFloat MelScale(BaseFloat freq) { + return 1127.0f * logf (1.0f + freq / 700.0f); + } + + static BaseFloat VtlnWarpFreq(BaseFloat vtln_low_cutoff, + BaseFloat vtln_high_cutoff, // discontinuities in warp func + BaseFloat low_freq, + BaseFloat high_freq, // upper+lower frequency cutoffs in + // the mel computation + BaseFloat vtln_warp_factor, + 
BaseFloat freq); + + static BaseFloat VtlnWarpMelFreq(BaseFloat vtln_low_cutoff, + BaseFloat vtln_high_cutoff, + BaseFloat low_freq, + BaseFloat high_freq, + BaseFloat vtln_warp_factor, + BaseFloat mel_freq); + + + MelBanks(const MelBanksOptions &opts, + const FrameExtractionOptions &frame_opts, + BaseFloat vtln_warp_factor); + + /// Compute Mel energies (note: not log energies). + /// At input, "fft_energies" contains the FFT energies (not log). + void Compute(const VectorBase &fft_energies, + VectorBase *mel_energies_out) const; + + int32 NumBins() const { return bins_.size(); } + + // returns vector of central freq of each bin; needed by plp code. + const Vector &GetCenterFreqs() const { return center_freqs_; } + + const std::vector > >& GetBins() const { + return bins_; + } + + // Copy constructor + MelBanks(const MelBanks &other); + private: + // Disallow assignment + MelBanks &operator = (const MelBanks &other); + + // center frequencies of bins, numbered from 0 ... num_bins-1. + // Needed by GetCenterFreqs(). + Vector center_freqs_; + + // the "bins_" vector is a vector, one for each bin, of a pair: + // (the first nonzero fft-bin), (the vector of weights). + std::vector > > bins_; + + bool debug_; + bool htk_mode_; +}; + + +// Compute liftering coefficients (scaling on cepstral coeffs) +// coeffs are numbered slightly differently from HTK: the zeroth +// index is C0, which is not affected. +void ComputeLifterCoeffs(BaseFloat Q, VectorBase *coeffs); + + +// Durbin's recursion - converts autocorrelation coefficients to the LPC +// pTmp - temporal place [n] +// pAC - autocorrelation coefficients [n + 1] +// pLP - linear prediction coefficients [n] (predicted_sn = sum_1^P{a[i-1] * s[n-i]}}) +// F(z) = 1 / (1 - A(z)), 1 is not stored in the denominator +// Returns log energy of residual (I think) +BaseFloat Durbin(int n, const BaseFloat *pAC, BaseFloat *pLP, BaseFloat *pTmp); + +// Compute LP coefficients from autocorrelation coefficients.
+// Returns log energy of residual (I think) +BaseFloat ComputeLpc(const VectorBase &autocorr_in, + Vector *lpc_out); + +void Lpc2Cepstrum(int n, const BaseFloat *pLPC, BaseFloat *pCepst); + + + +void GetEqualLoudnessVector(const MelBanks &mel_banks, + Vector *ans); + +/// @} End of "addtogroup feat" +} // namespace kaldi + +#endif // KALDI_FEAT_MEL_COMPUTATIONS_H_ diff --git a/torchaudio/csrc/kaldi/feat/online-feature.h b/torchaudio/csrc/kaldi/feat/online-feature.h new file mode 100644 index 00000000000..b9dfcc0171e --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/online-feature.h @@ -0,0 +1,632 @@ +// feat/online-feature.h + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+ + +#ifndef KALDI_FEAT_ONLINE_FEATURE_H_ +#define KALDI_FEAT_ONLINE_FEATURE_H_ + +#include +#include +#include + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" +#include "feat/feature-functions.h" +#include "feat/feature-mfcc.h" +#include "feat/feature-plp.h" +#include "feat/feature-fbank.h" +#include "itf/online-feature-itf.h" + +namespace kaldi { +/// @addtogroup onlinefeat OnlineFeatureExtraction +/// @{ + + +/// This class serves as a storage for feature vectors with an option to limit +/// the memory usage by removing old elements. The deleted frames indices are +/// "remembered" so that regardless of the MAX_ITEMS setting, the user always +/// provides the indices as if no deletion was being performed. +/// This is useful when processing very long recordings which would otherwise +/// cause the memory to eventually blow up when the features are not being removed. +class RecyclingVector { +public: + /// By default it does not remove any elements. + RecyclingVector(int items_to_hold = -1); + + /// The ownership is being retained by this collection - do not delete the item. + Vector *At(int index) const; + + /// The ownership of the item is passed to this collection - do not delete the item. + void PushBack(Vector *item); + + /// This method returns the size as if no "recycling" had happened, + /// i.e. equivalent to the number of times the PushBack method has been called. + int Size() const; + + ~RecyclingVector(); + +private: + std::deque*> items_; + int items_to_hold_; + int first_available_index_; +}; + + +/// This is a templated class for online feature extraction; +/// it's templated on a class like MfccComputer or PlpComputer +/// that does the basic feature extraction. 
+template +class OnlineGenericBaseFeature: public OnlineBaseFeature { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { return computer_.Dim(); } + + // Note: IsLastFrame() will only ever return true if you have called + // InputFinished() (and this frame is the last frame). + virtual bool IsLastFrame(int32 frame) const { + return input_finished_ && frame == NumFramesReady() - 1; + } + virtual BaseFloat FrameShiftInSeconds() const { + return computer_.GetFrameOptions().frame_shift_ms / 1000.0f; + } + + virtual int32 NumFramesReady() const { return features_.Size(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // Next, functions that are not in the interface. + + + // Constructor from options class + explicit OnlineGenericBaseFeature(const typename C::Options &opts); + + // This would be called from the application, when you get + // more wave data. Note: the sampling_rate is only provided so + // the code can assert that it matches the sampling rate + // expected in the options. + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase &waveform); + + + // InputFinished() tells the class you won't be providing any + // more waveform. This will help flush out the last frame or two + // of features, in the case where snip-edges == false; it also + // affects the return value of IsLastFrame(). + virtual void InputFinished(); + + private: + // This function computes any additional feature frames that it is possible to + // compute from 'waveform_remainder_', which at this point may contain more + // than just a remainder-sized quantity (because AcceptWaveform() appends to + // waveform_remainder_ before calling this function). It adds these feature + // frames to features_, and shifts off any now-unneeded samples of input from + // waveform_remainder_ while incrementing waveform_offset_ by the same amount. 
+ void ComputeFeatures(); + + void MaybeCreateResampler(BaseFloat sampling_rate); + + C computer_; // class that does the MFCC or PLP or filterbank computation + + // resampler in cases when the input sampling frequency is not equal to + // the expected sampling rate + std::unique_ptr resampler_; + + FeatureWindowFunction window_function_; + + // features_ is the Mfcc or Plp or Fbank features that we have already computed. + + RecyclingVector features_; + + // True if the user has called "InputFinished()" + bool input_finished_; + + // The sampling frequency, extracted from the config. Should + // be identical to the waveform supplied. + BaseFloat sampling_frequency_; + + // waveform_offset_ is the number of samples of waveform that we have + // already discarded, i.e. that were prior to 'waveform_remainder_'. + int64 waveform_offset_; + + // waveform_remainder_ is a short piece of waveform that we may need to keep + // after extracting all the whole frames we can (whatever length of feature + // will be required for the next phase of computation). + Vector waveform_remainder_; +}; + +typedef OnlineGenericBaseFeature OnlineMfcc; +typedef OnlineGenericBaseFeature OnlinePlp; +typedef OnlineGenericBaseFeature OnlineFbank; + + +/// This class takes a Matrix and wraps it as an +/// OnlineFeatureInterface: this can be useful where some earlier stage of +/// feature processing has been done offline but you want to use part of the +/// online pipeline. +class OnlineMatrixFeature: public OnlineFeatureInterface { + public: + /// Caution: this class maintains the const reference from the constructor, so + /// don't let it go out of scope while this object exists. 
+ explicit OnlineMatrixFeature(const MatrixBase &mat): mat_(mat) { } + + virtual int32 Dim() const { return mat_.NumCols(); } + + virtual BaseFloat FrameShiftInSeconds() const { + return 0.01f; + } + + virtual int32 NumFramesReady() const { return mat_.NumRows(); } + + virtual void GetFrame(int32 frame, VectorBase *feat) { + feat->CopyFromVec(mat_.Row(frame)); + } + + virtual bool IsLastFrame(int32 frame) const { + return (frame + 1 == mat_.NumRows()); + } + + + private: + const MatrixBase &mat_; +}; + + +// Note the similarity with SlidingWindowCmnOptions, but there +// are also differences. One which doesn't appear in the config +// itself, because it's a difference between the setups, is that +// in OnlineCmn, we carry over data from the previous utterance, +// or, if no previous utterance is available, from global stats, +// or, if previous utterances are available but the total amount +// of data is less than prev_frames, we pad with up to "global_frames" +// frames from the global stats. +struct OnlineCmvnOptions { + int32 cmn_window; + int32 speaker_frames; // must be <= cmn_window + int32 global_frames; // must be <= speaker_frames. + bool normalize_mean; // Must be true if normalize_variance==true. + bool normalize_variance; + + int32 modulus; // not configurable from command line, relates to how the + // class computes the cmvn internally. smaller->more + // time-efficient but less memory-efficient. Must be >= 1. + int32 ring_buffer_size; // not configurable from command line; size of ring + // buffer used for caching CMVN stats. Must be >= + // modulus. + std::string skip_dims; // Colon-separated list of dimensions to skip normalization + // of, e.g. 13:14:15. 
+ + OnlineCmvnOptions(): + cmn_window(600), + speaker_frames(600), + global_frames(200), + normalize_mean(true), + normalize_variance(false), + modulus(20), + ring_buffer_size(20), + skip_dims("") { } + + void Check() const { + KALDI_ASSERT(speaker_frames <= cmn_window && global_frames <= speaker_frames + && modulus > 0); + } + + void Register(OptionsItf *po) { + po->Register("cmn-window", &cmn_window, "Number of frames of sliding " + "context for cepstral mean normalization."); + po->Register("global-frames", &global_frames, "Number of frames of " + "global-average cepstral mean normalization stats to use for " + "first utterance of a speaker"); + po->Register("speaker-frames", &speaker_frames, "Number of frames of " + "previous utterance(s) from this speaker to use in cepstral " + "mean normalization"); + // we name the config string "norm-vars" for compatibility with + // ../featbin/apply-cmvn.cc + po->Register("norm-vars", &normalize_variance, "If true, do " + "cepstral variance normalization in addition to cepstral mean " + "normalization "); + po->Register("norm-means", &normalize_mean, "If true, do mean normalization " + "(note: you cannot normalize the variance but not the mean)"); + po->Register("skip-dims", &skip_dims, "Dimensions to skip normalization of " + "(colon-separated list of integers)");} +}; + + + +/** Struct OnlineCmvnState stores the state of CMVN adaptation between + utterances (but not the state of the computation within an utterance). It + stores the global CMVN stats and the stats of the current speaker (if we + have seen previous utterances for this speaker), and possibly will have a + member "frozen_state": if the user has called the function Freeze() of class + OnlineCmvn, to fix the CMVN so we can estimate fMLLR on top of the fixed + value of cmvn. If nonempty, "frozen_state" will reflect how we were + normalizing the mean and (if applicable) variance at the time when that + function was called. 
+*/ +struct OnlineCmvnState { + // The following is the total CMVN stats for this speaker (up till now), in + // the same format. + Matrix speaker_cmvn_stats; + + // The following is the global CMVN stats, in the usual + // format, of dimension 2 x (dim+1), as [ sum-stats count + // sum-squared-stats 0 ] + Matrix global_cmvn_stats; + + // If nonempty, contains CMVN stats representing the "frozen" state + // of CMVN that reflects how we were normalizing the data when the + // user called the Freeze() function in class OnlineCmvn. + Matrix frozen_state; + + OnlineCmvnState() { } + + explicit OnlineCmvnState(const Matrix &global_stats): + global_cmvn_stats(global_stats) { } + + // Copy constructor + OnlineCmvnState(const OnlineCmvnState &other); + + void Write(std::ostream &os, bool binary) const; + void Read(std::istream &is, bool binary); + + // Use the default assignment operator. +}; + +/** + This class does an online version of the cepstral mean and [optionally] + variance, but note that this is not equivalent to the offline version. This + is necessarily so, as the offline computation involves looking into the + future. If you plan to use features normalized with this type of CMVN then + you need to train in a `matched' way, i.e. with the same type of features. + We normally only do so in the "online" GMM-based decoding, e.g. in + online2bin/online2-wav-gmm-latgen-faster.cc; see also the script + steps/online/prepare_online_decoding.sh and steps/online/decode.sh. + + In the steady state (in the middle of a long utterance), this class + accumulates CMVN statistics from the previous "cmn_window" frames (default 600 + frames, or 6 seconds), and uses these to normalize the mean and possibly + variance of the current frame. + + The config variables "speaker_frames" and "global_frames" relate to what + happens at the beginning of the utterance when we have seen fewer than + "cmn_window" frames of context, and so might not have very good stats to + normalize with. 
Basically, we first augment any existing stats with up + to "speaker_frames" frames of stats from previous utterances of the current + speaker, and if this doesn't take us up to the required "cmn_window" frame + count, we further augment with up to "global_frames" frames of global + stats. The global stats are CMVN stats accumulated from training or testing + data, that give us a reasonable source of mean and variance for "typical" + data. + */ +class OnlineCmvn: public OnlineFeatureInterface { + public: + + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { return src_->Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + // The online cmvn does not introduce any additional latency. + virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // + // Next, functions that are not in the interface. + // + + /// Initializer that sets the cmvn state. If you don't have previous + /// utterances from the same speaker you are supposed to initialize the CMVN + /// state from some global CMVN stats, which you can get from summing all cmvn + /// stats you have in your training data using "sum-matrix". This just gives + /// it a reasonable starting point at the start of the file. + /// If you do have previous utterances from the same speaker or at least a + /// similar environment, you are supposed to initialize it by calling GetState + /// from the previous utterance + OnlineCmvn(const OnlineCmvnOptions &opts, + const OnlineCmvnState &cmvn_state, + OnlineFeatureInterface *src); + + /// Initializer that does not set the cmvn state: + /// after calling this, you should call SetState(). 
+ OnlineCmvn(const OnlineCmvnOptions &opts, + OnlineFeatureInterface *src); + + // Outputs any state information from this utterance to "cmvn_state". + // The value of "cmvn_state" before the call does not matter: the output + // depends on the value of OnlineCmvnState the class was initialized + // with, the input feature values up to cur_frame, and the effects + // of the user possibly having called Freeze(). + // If cur_frame is -1, it will just output the unmodified original + // state that was supplied to this object. + void GetState(int32 cur_frame, + OnlineCmvnState *cmvn_state); + + // This function can be used to modify the state of the CMVN computation + // from outside, but must only be called before you have processed any data + // (otherwise it will crash). This "state" is really just the information + // that is propagated between utterances, not the state of the computation + // inside an utterance. + void SetState(const OnlineCmvnState &cmvn_state); + + // From this point it will freeze the CMN to what it would have been if + // measured at frame "cur_frame", and it will stop it from changing + // further. This also applies retroactively for this utterance, so if you + // call GetFrame() on previous frames, it will use the CMVN stats + // from cur_frame; and it applies in the future too if you then + // call GetState() and use this state to initialize the next + // utterance's CMVN object. + void Freeze(int32 cur_frame); + + virtual ~OnlineCmvn(); + protected: + + /// Smooth the CMVN stats "stats" (which are stored in the normal format as a + /// 2 x (dim+1) matrix), by possibly adding some stats from "global_stats" + /// and/or "speaker_stats", controlled by the config. The best way to + /// understand the smoothing rule we use is just to look at the code.
+ static void SmoothOnlineCmvnStats(const MatrixBase &speaker_stats, + const MatrixBase &global_stats, + const OnlineCmvnOptions &opts, + MatrixBase *stats); + + /// Get the most recent cached frame of CMVN stats. [If no frames + /// were cached, sets up empty stats for frame zero and returns that]. + void GetMostRecentCachedFrame(int32 frame, + int32 *cached_frame, + MatrixBase *stats); + + /// Cache this frame of stats. + void CacheFrame(int32 frame, const MatrixBase &stats); + + /// Initialize ring buffer for caching stats. + inline void InitRingBufferIfNeeded(); + + /// Computes the raw CMVN stats for this frame, making use of (and updating if + /// necessary) the cached statistics in raw_stats_. This means the (x, + /// x^2, count) stats for the last up to opts_.cmn_window frames. + void ComputeStatsForFrame(int32 frame, + MatrixBase *stats); + + + OnlineCmvnOptions opts_; + std::vector skip_dims_; // Skip CMVN for these dimensions. Derived from opts_. + OnlineCmvnState orig_state_; // reflects the state before we saw this + // utterance. + Matrix frozen_state_; // If the user called Freeze(), this variable + // will reflect the CMVN state that we froze + // at. + + // The variable below reflects the raw (count, x, x^2) statistics of the + // input, computed every opts_.modulus frames. cached_stats_modulo_[n / opts_.modulus] + // contains the (count, x, x^2) statistics for the frames from + // std::max(0, n - opts_.cmn_window) through n. + std::vector*> cached_stats_modulo_; + // the variable below is a ring-buffer of cached stats. the int32 is the + // frame index. + std::vector > > cached_stats_ring_; + + // Some temporary variables used inside functions of this class, which + // are put here to avoid reallocation.
+ Matrix temp_stats_; + Vector temp_feats_; + Vector temp_feats_dbl_; + + OnlineFeatureInterface *src_; // Not owned here +}; + + +struct OnlineSpliceOptions { + int32 left_context; + int32 right_context; + OnlineSpliceOptions(): left_context(4), right_context(4) { } + void Register(OptionsItf *po) { + po->Register("left-context", &left_context, "Left-context for frame " + "splicing prior to LDA"); + po->Register("right-context", &right_context, "Right-context for frame " + "splicing prior to LDA"); + } +}; + +class OnlineSpliceFrames: public OnlineFeatureInterface { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { + return src_->Dim() * (1 + left_context_ + right_context_); + } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const; + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // + // Next, functions that are not in the interface. + // + OnlineSpliceFrames(const OnlineSpliceOptions &opts, + OnlineFeatureInterface *src): + left_context_(opts.left_context), right_context_(opts.right_context), + src_(src) { } + + private: + int32 left_context_; + int32 right_context_; + OnlineFeatureInterface *src_; // Not owned here +}; + +/// This online-feature class implements any affine or linear transform. 
+class OnlineTransform: public OnlineFeatureInterface { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const { return offset_.Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + + // + // Next, functions that are not in the interface. + // + + /// The transform can be a linear transform, or an affine transform + /// where the last column is the offset. + OnlineTransform(const MatrixBase &transform, + OnlineFeatureInterface *src); + + + private: + OnlineFeatureInterface *src_; // Not owned here + Matrix linear_term_; + Vector offset_; +}; + +class OnlineDeltaFeature: public OnlineFeatureInterface { + public: + // + // First, functions that are present in the interface: + // + virtual int32 Dim() const; + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const; + + virtual void GetFrame(int32 frame, VectorBase *feat); + + // + // Next, functions that are not in the interface. + // + OnlineDeltaFeature(const DeltaFeaturesOptions &opts, + OnlineFeatureInterface *src); + + private: + OnlineFeatureInterface *src_; // Not owned here + DeltaFeaturesOptions opts_; + DeltaFeatures delta_features_; // This class contains just a few + // coefficients. +}; + + +/// This feature type can be used to cache its input, to avoid +/// repetition of computation in a multi-pass decoding context. 
+class OnlineCacheFeature: public OnlineFeatureInterface { + public: + virtual int32 Dim() const { return src_->Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return src_->IsLastFrame(frame); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const { return src_->NumFramesReady(); } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats); + + virtual ~OnlineCacheFeature() { ClearCache(); } + + // Things that are not in the shared interface: + + void ClearCache(); // this should be called if you change the underlying + // features in some way. + + explicit OnlineCacheFeature(OnlineFeatureInterface *src): src_(src) { } + private: + + OnlineFeatureInterface *src_; // Not owned here + std::vector* > cache_; +}; + + + + +/// This online-feature class implements combination of two feature +/// streams (such as pitch, plp) into one stream. 
+class OnlineAppendFeature: public OnlineFeatureInterface { + public: + virtual int32 Dim() const { return src1_->Dim() + src2_->Dim(); } + + virtual bool IsLastFrame(int32 frame) const { + return (src1_->IsLastFrame(frame) || src2_->IsLastFrame(frame)); + } + // Hopefully sources have the same rate + virtual BaseFloat FrameShiftInSeconds() const { + return src1_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const { + return std::min(src1_->NumFramesReady(), src2_->NumFramesReady()); + } + + virtual void GetFrame(int32 frame, VectorBase *feat); + + virtual ~OnlineAppendFeature() { } + + OnlineAppendFeature(OnlineFeatureInterface *src1, + OnlineFeatureInterface *src2): src1_(src1), src2_(src2) { } + private: + + OnlineFeatureInterface *src1_; + OnlineFeatureInterface *src2_; +}; + +/// @} End of "addtogroup onlinefeat" +} // namespace kaldi + +#endif // KALDI_FEAT_ONLINE_FEATURE_H_ diff --git a/torchaudio/csrc/kaldi/feat/pitch-functions.cc b/torchaudio/csrc/kaldi/feat/pitch-functions.cc new file mode 100644 index 00000000000..430e9bdb53a --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/pitch-functions.cc @@ -0,0 +1,1667 @@ +// feat/pitch-functions.cc + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer +// Xin Lei + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. 
+// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "feat/feature-functions.h" +#include "feat/mel-computations.h" +#include "feat/online-feature.h" +#include "feat/pitch-functions.h" +#include "feat/resample.h" +#include "matrix/matrix-functions.h" + +namespace kaldi { + +/** + This function processes the NCCF n to a POV feature f by applying the formula + f = (1.0001 - n)^0.15 - 1.0 + This is a nonlinear function designed to make the output reasonably Gaussian + distributed. Before doing this, the NCCF distribution is in the range [-1, + 1] but has a strong peak just before 1.0, which this function smooths out. +*/ + +BaseFloat NccfToPovFeature(BaseFloat n) { + if (n > 1.0) { + n = 1.0; + } else if (n < -1.0) { + n = -1.0; + } + BaseFloat f = pow((1.0001 - n), 0.15) - 1.0; + KALDI_ASSERT(f - f == 0); // check for NaN,inf. + return f; +} + +/** + This function processes the NCCF n to a reasonably accurate probability + of voicing p by applying the formula: + + n' = fabs(n) + r = -5.2 + 5.4 * exp(7.5 * (n' - 1.0)) + + 4.8 * n' - 2.0 * exp(-10.0 * n') + 4.2 * exp(20.0 * (n' - 1.0)); + p = 1.0 / (1 + exp(-1.0 * r)); + + How did we get this formula? We plotted the empirical log-prob-ratio of voicing + r = log( p[voiced] / p[not-voiced] ) + [on the Keele database where voicing is marked], as a function of the NCCF at + the delay picked by our algorithm. This was done on intervals of the NCCF, so + we had enough statistics to get that ratio. The NCCF covers [-1, 1]; almost + all of the probability mass is on [0, 1] but the empirical POV seems fairly + symmetric with a minimum near zero, so we chose to make it a function of n' = fabs(n). 
+ + Then we manually tuned a function (the one you see above) that approximated + the log-prob-ratio of voicing fairly well as a function of the absolute-value + NCCF n'; however, it wasn't a very exact match since we were also trying to make + the transformed NCCF fairly Gaussian distributed, with a view to using it as + a feature-- an idea we later abandoned after a simpler formula worked better. + */ +BaseFloat NccfToPov(BaseFloat n) { + BaseFloat ndash = fabs(n); + if (ndash > 1.0) ndash = 1.0; // just in case it was slightly outside [-1, 1] + + BaseFloat r = -5.2 + 5.4 * Exp(7.5 * (ndash - 1.0)) + 4.8 * ndash - + 2.0 * Exp(-10.0 * ndash) + 4.2 * Exp(20.0 * (ndash - 1.0)); + // r is the approximate log-prob-ratio of voicing, log(p/(1-p)). + BaseFloat p = 1.0 / (1 + Exp(-1.0 * r)); + KALDI_ASSERT(p - p == 0); // Check for NaN/inf + return p; +} + +/** + This function computes some dot products that are required + while computing the NCCF. + For each integer lag from first_lag to last_lag (inclusive), this function + outputs to (*inner_prod)(lag - first_lag), the dot-product + of a window starting at 0 with a window starting at + lag. All windows are of length nccf_window_size. It + outputs to (*norm_prod)(lag - first_lag), e1 * e2, where + e1 is the dot-product of the un-shifted window with itself, + and e2 is the dot-product of the window shifted by "lag" + with itself. + */ +void ComputeCorrelation(const VectorBase &wave, + int32 first_lag, int32 last_lag, + int32 nccf_window_size, + VectorBase *inner_prod, + VectorBase *norm_prod) { + Vector zero_mean_wave(wave); + // TODO: possibly fix this, the mean normalization is done in a strange way.
+ SubVector wave_part(wave, 0, nccf_window_size); + // subtract mean-frame from wave + zero_mean_wave.Add(-wave_part.Sum() / nccf_window_size); + BaseFloat e1, e2, sum; + SubVector sub_vec1(zero_mean_wave, 0, nccf_window_size); + e1 = VecVec(sub_vec1, sub_vec1); + for (int32 lag = first_lag; lag <= last_lag; lag++) { + SubVector sub_vec2(zero_mean_wave, lag, nccf_window_size); + e2 = VecVec(sub_vec2, sub_vec2); + sum = VecVec(sub_vec1, sub_vec2); + (*inner_prod)(lag - first_lag) = sum; + (*norm_prod)(lag - first_lag) = e1 * e2; + } +} + +/** + Computes the NCCF as a fraction of the numerator term (a dot product between + two vectors) and a denominator term which equals sqrt(e1*e2 + nccf_ballast) + where e1 and e2 are both dot-products of bits of the wave with themselves, + and e1*e2 is supplied as "norm_prod". These quantities are computed by + "ComputeCorrelation". +*/ +void ComputeNccf(const VectorBase &inner_prod, + const VectorBase &norm_prod, + BaseFloat nccf_ballast, + VectorBase *nccf_vec) { + KALDI_ASSERT(inner_prod.Dim() == norm_prod.Dim() && + inner_prod.Dim() == nccf_vec->Dim()); + for (int32 lag = 0; lag < inner_prod.Dim(); lag++) { + BaseFloat numerator = inner_prod(lag), + denominator = pow(norm_prod(lag) + nccf_ballast, 0.5), + nccf; + if (denominator != 0.0) { + nccf = numerator / denominator; + } else { + KALDI_ASSERT(numerator == 0.0); + nccf = 0.0; + } + KALDI_ASSERT(nccf < 1.01 && nccf > -1.01); + (*nccf_vec)(lag) = nccf; + } +} + +/** + This function selects the lags at which we measure the NCCF: we need + to select lags from 1/max_f0 to 1/min_f0, in a geometric progression + with ratio 1 + d. 
+ */
+void SelectLags(const PitchExtractionOptions &opts,
+                Vector *lags) {
+  // choose lags relative to acceptable pitch tolerance
+  BaseFloat min_lag = 1.0 / opts.max_f0, max_lag = 1.0 / opts.min_f0;
+
+  // Each lag is the previous one multiplied by (1.0 + opts.delta_pitch),
+  // starting at min_lag and stopping once max_lag would be exceeded.  The
+  // values are collected in a temporary first because their number is not
+  // known in advance, then copied into *lags.
+  std::vector tmp_lags;
+  for (BaseFloat lag = min_lag; lag <= max_lag; lag *= 1.0 + opts.delta_pitch)
+    tmp_lags.push_back(lag);
+  lags->Resize(tmp_lags.size());
+  std::copy(tmp_lags.begin(), tmp_lags.end(), lags->Data());
+}
+
+
+/**
+   This function computes the local-cost for the Viterbi computation,
+   see eq. (5) in the paper.
+   @param  nccf_pitch   The nccf as computed for the pitch computation (with ballast).
+   @param  lags         The log-spaced lags at which nccf_pitch is sampled.
+   @param  opts         The options as provided by the user (only soft_min_f0
+                        is read here).
+   @param  local_cost   We output the local-cost to here.  It must already have
+                        the same dimension as nccf_pitch and lags (asserted).
+*/
+void ComputeLocalCost(const VectorBase &nccf_pitch,
+                      const VectorBase &lags,
+                      const PitchExtractionOptions &opts,
+                      VectorBase *local_cost) {
+  // from the paper, eq. 5, local_cost = 1 - Phi(t,i)(1 - soft_min_f0 L_i)
+  // nccf is the nccf on this frame measured at the lags in "lags".
+  KALDI_ASSERT(nccf_pitch.Dim() == local_cost->Dim() &&
+               nccf_pitch.Dim() == lags.Dim());
+  local_cost->Set(1.0);
+  // add the term -Phi(t,i):
+  local_cost->AddVec(-1.0, nccf_pitch);
+  // add the term soft_min_f0 Phi(t,i) L_i
+  local_cost->AddVecVec(opts.soft_min_f0, lags, nccf_pitch, 1.0);
+}
+
+
+
+// class PitchFrameInfo is used inside class OnlinePitchFeatureImpl.
+// It stores the information we need to keep around for a single frame
+// of the pitch computation.
+// NOTE(review): the element types of the std::vector members and parameters
+// below are missing from this text (the angle-bracket parts appear to have
+// been lost) -- TODO restore them.
+class PitchFrameInfo {
+ public:
+  /// This function resizes the arrays for this object and updates the reference
+  /// counts for the previous object (by decrementing those reference counts
+  /// when we destroy a StateInfo object).  A StateInfo object is considered to
+  /// be destroyed when we delete it, not when its reference counts goes to
+  /// zero.
+  /// NOTE(review): the definition later in this file does none of the above;
+  /// it only raises an error saying "Cleanup not implemented." -- confirm
+  /// whether this description is still wanted.
+  void Cleanup(PitchFrameInfo *prev_frame);
+
+  /// This function may be called for the last (most recent) PitchFrameInfo
+  /// object with the best state (obtained from the externally held
+  /// forward-costs).  It traces back as far as needed to set the
+  /// cur_best_state_, and as it's going it sets the lag-index and pov_nccf in
+  /// lag_nccf.  It writes starting from the last element of lag_nccf (which
+  /// receives the info for this, most recent, frame) and moves towards the
+  /// front as it follows the backpointers to earlier frames.
+  void SetBestState(int32 best_state,
+                    std::vector > &lag_nccf);
+
+  /// This function may be called on the last (most recent) PitchFrameInfo
+  /// object; it computes how many frames of latency there is because the
+  /// traceback has not yet settled on a single value for frames in the past.
+  /// It actually returns the minimum of max_latency and the actual latency,
+  /// which is an optimization because we won't care about latency past
+  /// a user-specified maximum latency.
+  int32 ComputeLatency(int32 max_latency);
+
+  /// This function updates
+  /// NOTE(review): the sentence above is incomplete in the original, and no
+  /// definition of this function is visible in the part of the file reviewed
+  /// -- confirm what it is meant to do, or whether it is still needed.
+  bool UpdatePreviousBestState(PitchFrameInfo *prev_frame);
+
+  /// This constructor is used for frame -1; it sets the costs to be all zeros
+  /// the pov_nccf's to zero and the backpointers to -1.
+  /// NOTE(review): as written, StateInfo's default constructor sets each
+  /// backpointer to 0, not -1; it is cur_best_state_ that is set to -1.
+  explicit PitchFrameInfo(int32 num_states);
+
+  /// This constructor is used for subsequent frames (not -1).
+  PitchFrameInfo(PitchFrameInfo *prev);
+
+  /// Record the nccf_pov value.
+  ///  @param  nccf_pov     The nccf as computed for the POV computation (without ballast).
+  void SetNccfPov(const VectorBase &nccf_pov);
+
+  /// This function is used for frames apart from frame -1; the bulk of
+  /// the Viterbi computation takes place inside this function.
+  ///  @param  opts         The options as provided by the user
+  ///  @param  nccf_pitch   The nccf as computed for the pitch computation
+  ///                       (with ballast).
+  ///  @param  lags         The log-spaced lags at which nccf_pitch is
+  ///                       sampled.
+  ///  @param  prev_forward_cost   The forward-cost vector for the
+  ///                       previous frame.
+  ///  @param  index_info   A pointer to a temporary vector used by this function
+  ///  @param  this_forward_cost   The forward-cost vector for this frame
+  ///                       (to be computed).
+  /// (The nccf for the POV computation is not an argument here; it is
+  /// recorded separately through SetNccfPov().)
+  void ComputeBacktraces(const PitchExtractionOptions &opts,
+                         const VectorBase &nccf_pitch,
+                         const VectorBase &lags,
+                         const VectorBase &prev_forward_cost,
+                         std::vector > *index_info,
+                         VectorBase *this_forward_cost);
+ private:
+  // struct StateInfo is the information we keep for a single one of the
+  // log-spaced lags, for a single frame.  This is a state in the Viterbi
+  // computation.
+  struct StateInfo {
+    /// The state index on the previous frame that is the best preceding state
+    /// for this state.
+    int32 backpointer;
+    /// the version of the NCCF we keep for the POV computation (without the
+    /// ballast term).
+    BaseFloat pov_nccf;
+    // Default state: backpointer 0 and pov_nccf 0.0.
+    StateInfo(): backpointer(0), pov_nccf(0.0) { }
+  };
+  // One StateInfo per state (i.e. per log-spaced lag) for this frame.
+  std::vector state_info_;
+  /// the state index of the first entry in "state_info"; this will initially be
+  /// zero, but after cleanup might be nonzero.
+  int32 state_offset_;
+
+  /// The current best state in the backtrace from the end.
+  /// -1 means that no best state has been recorded.
+  int32 cur_best_state_;
+
+  /// The structure for the previous frame.
+  /// NULL for the object that represents frame -1.
+  PitchFrameInfo *prev_info_;
+};
+
+
+// This constructor is used for frame -1.  It creates num_states
+// default-constructed StateInfo entries (pov_nccf 0.0 and backpointer 0),
+// sets state_offset_ to 0, marks cur_best_state_ as unset (-1) and has no
+// previous frame (prev_info_ is NULL).
+PitchFrameInfo::PitchFrameInfo(int32 num_states)
+    :state_info_(num_states), state_offset_(0),
+    cur_best_state_(-1), prev_info_(NULL) { }
+
+
+// When true, PitchFrameInfo::ComputeBacktraces() uses the exhaustive search
+// over all previous states instead of the bound-tightening search.
+bool pitch_use_naive_search = false;  // This is used in unit-tests.
+
+
+// This constructor is used for frames other than frame -1.  The new object
+// gets as many (default-constructed) states as prev_info has, and keeps a
+// pointer to prev_info so that tracebacks can walk back through the frames.
+PitchFrameInfo::PitchFrameInfo(PitchFrameInfo *prev_info):
+    state_info_(prev_info->state_info_.size()), state_offset_(0),
+    cur_best_state_(-1), prev_info_(prev_info) { }
+
+// Copies nccf_pov(i) into state_info_[i].pov_nccf for every state; the
+// dimension of nccf_pov must equal the number of states.
+void PitchFrameInfo::SetNccfPov(const VectorBase &nccf_pov) {
+  int32 num_states = nccf_pov.Dim();
+  KALDI_ASSERT(num_states == state_info_.size());
+  for (int32 i = 0; i < num_states; i++)
+    state_info_[i].pov_nccf = nccf_pov(i);
+}
+
+// For each state i of this frame, finds the state j of the previous frame
+// that minimizes
+//   (j - i)^2 * inter_frame_factor + prev_forward_cost[j],
+// stores that j as the backpointer of state i and the minimum as
+// this_forward_cost[i]; the local cost of each state is added at the very end.
+// Unless pitch_use_naive_search is set, the minimization is done by
+// progressively tightening per-state lower/upper bounds on the backpointer
+// (kept in *index_info) instead of trying every j for every i.
+void PitchFrameInfo::ComputeBacktraces(
+    const PitchExtractionOptions &opts,
+    const VectorBase &nccf_pitch,
+    const VectorBase &lags,
+    const VectorBase &prev_forward_cost_vec,
+    std::vector > *index_info,
+    VectorBase *this_forward_cost_vec) {
+  int32 num_states = nccf_pitch.Dim();
+
+  Vector local_cost(num_states, kUndefined);
+  ComputeLocalCost(nccf_pitch, lags, opts, &local_cost);
+
+  const BaseFloat delta_pitch_sq = pow(Log(1.0 + opts.delta_pitch), 2.0),
+      inter_frame_factor = delta_pitch_sq * opts.penalty_factor;
+
+  // index local_cost, prev_forward_cost and this_forward_cost using raw pointer
+  // indexing not operator (), since this is the very inner loop and a lot of
+  // time is taken here.
+  const BaseFloat *prev_forward_cost = prev_forward_cost_vec.Data();
+  BaseFloat *this_forward_cost = this_forward_cost_vec->Data();
+
+  // The caller may pass an empty vector; it is sized here on first use.
+  if (index_info->empty())
+    index_info->resize(num_states);
+
+  // make it a reference for more concise indexing.
+  std::vector > &bounds = *index_info;
+
+  /* bounds[i].first will be a lower bound on the backpointer for state i,
+     bounds[i].second will be an upper bound on it.  We progressively tighten
+     these bounds till we know the backpointers exactly.
+  */
+
+  if (pitch_use_naive_search) {
+    // This branch is only taken in unit-testing code.
+    // It tries every previous state j for every state i.
+    for (int32 i = 0; i < num_states; i++) {
+      BaseFloat best_cost = std::numeric_limits::infinity();
+      int32 best_j = -1;
+      for (int32 j = 0; j < num_states; j++) {
+        BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor
+            + prev_forward_cost[j];
+        if (this_cost < best_cost) {
+          best_cost = this_cost;
+          best_j = j;
+        }
+      }
+      this_forward_cost[i] = best_cost;
+      state_info_[i].backpointer = best_j;
+    }
+  } else {
+    // Initial forward sweep: for each state i, start searching at the
+    // backpointer found for state i - 1 and move upwards while the cost keeps
+    // improving.
+    int32 last_backpointer = 0;
+    for (int32 i = 0; i < num_states; i++) {
+      int32 start_j = last_backpointer;
+      BaseFloat best_cost = (start_j - i) * (start_j - i) * inter_frame_factor
+          + prev_forward_cost[start_j];
+      int32 best_j = start_j;
+
+      for (int32 j = start_j + 1; j < num_states; j++) {
+        BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor
+            + prev_forward_cost[j];
+        if (this_cost < best_cost) {
+          best_cost = this_cost;
+          best_j = j;
+        } else {  // as soon as the costs stop improving, we stop searching.
+          break;  // this is a loose lower bound we're getting.
+        }
+      }
+      state_info_[i].backpointer = best_j;
+      this_forward_cost[i] = best_cost;
+      bounds[i].first = best_j;  // this is now a lower bound on the
+                                 // backpointer.
+      bounds[i].second = num_states - 1;  // we have no meaningful upper bound
+                                          // yet.
+      last_backpointer = best_j;
+    }
+
+    // We iterate, progressively refining the upper and lower bounds until they
+    // meet and we know that the resulting backtraces are optimal.  Each
+    // iteration takes time linear in num_states.  We won't normally iterate as
+    // far as num_states; normally we only do two iterations; when printing out
+    // the number of iterations, it's rarely more than that (once I saw seven
+    // iterations).  Anyway, this part of the computation does not dominate.
+    // Even-numbered iterations sweep backwards and tighten the upper bounds;
+    // odd-numbered iterations sweep forwards and tighten the lower bounds.
+    // We stop at the first iteration that changes no backpointer.
+    for (int32 iter = 0; iter < num_states; iter++) {
+      bool changed = false;
+      if (iter % 2 == 0) {  // go backwards through the states
+        last_backpointer = num_states - 1;
+        for (int32 i = num_states - 1; i >= 0; i--) {
+          // The backpointer chosen for state i + 1 also caps the upper bound
+          // for state i.
+          int32 lower_bound = bounds[i].first,
+              upper_bound = std::min(last_backpointer, bounds[i].second);
+          if (upper_bound == lower_bound) {
+            last_backpointer = lower_bound;
+            continue;
+          }
+          BaseFloat best_cost = this_forward_cost[i];
+          int32 best_j = state_info_[i].backpointer, initial_best_j = best_j;
+
+          if (best_j == upper_bound) {
+            // if best_j already equals upper bound, don't bother tightening the
+            // upper bound, we'll tighten the lower bound when the time comes.
+            last_backpointer = best_j;
+            continue;
+          }
+          // Below, we have j > lower_bound + 1 because we know we've already
+          // evaluated lower_bound and lower_bound + 1 [via knowledge of
+          // this algorithm.]
+          for (int32 j = upper_bound; j > lower_bound + 1; j--) {
+            BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor
+                + prev_forward_cost[j];
+            if (this_cost < best_cost) {
+              best_cost = this_cost;
+              best_j = j;
+            } else {  // as soon as the costs stop improving, we stop searching,
+              // unless the best j is still lower than j, in which case
+              // we obviously need to keep moving.
+              if (best_j > j)
+                break;  // this is a loose upper bound we're getting.
+            }
+          }
+          // our "best_j" is now an upper bound on the backpointer.
+          bounds[i].second = best_j;
+          if (best_j != initial_best_j) {
+            this_forward_cost[i] = best_cost;
+            state_info_[i].backpointer = best_j;
+            changed = true;
+          }
+          last_backpointer = best_j;
+        }
+      } else {  // go forwards through the states.
+        last_backpointer = 0;
+        for (int32 i = 0; i < num_states; i++) {
+          // The backpointer chosen for state i - 1 also raises the lower bound
+          // for state i.
+          int32 lower_bound = std::max(last_backpointer, bounds[i].first),
+              upper_bound = bounds[i].second;
+          if (upper_bound == lower_bound) {
+            last_backpointer = lower_bound;
+            continue;
+          }
+          BaseFloat best_cost = this_forward_cost[i];
+          int32 best_j = state_info_[i].backpointer, initial_best_j = best_j;
+
+          if (best_j == lower_bound) {
+            // if best_j already equals lower bound, we don't bother tightening
+            // the lower bound, we'll tighten the upper bound when the time
+            // comes.
+            last_backpointer = best_j;
+            continue;
+          }
+          // Below, we have j < upper_bound - 1 because we know we've already
+          // evaluated the points above that.
+          // NOTE(review): the original comment said "j < upper_bound"; the
+          // loop actually stops one earlier -- presumably the mirror image of
+          // the "lower_bound + 1" limit in the backward pass; confirm.
+          for (int32 j = lower_bound; j < upper_bound - 1; j++) {
+            BaseFloat this_cost = (j - i) * (j - i) * inter_frame_factor
+                + prev_forward_cost[j];
+            if (this_cost < best_cost) {
+              best_cost = this_cost;
+              best_j = j;
+            } else {  // as soon as the costs stop improving, we stop searching,
+              // unless the best j is still higher than j, in which case
+              // we obviously need to keep moving.
+              if (best_j < j)
+                break;  // this is a loose lower bound we're getting.
+            }
+          }
+          // our "best_j" is now a lower bound on the backpointer.
+          bounds[i].first = best_j;
+          if (best_j != initial_best_j) {
+            this_forward_cost[i] = best_cost;
+            state_info_[i].backpointer = best_j;
+            changed = true;
+          }
+          last_backpointer = best_j;
+        }
+      }
+      if (!changed)
+        break;
+    }
+  }
+  // The next statement is needed due to RecomputeBacktraces: we have to
+  // invalidate the previously computed best-state info.
+  cur_best_state_ = -1;
+  // Up to here this_forward_cost only holds the transition part of the cost;
+  // now add the local cost of each state.
+  this_forward_cost_vec->AddVec(1.0, local_cost);
+}
+
+// Follows the backpointers from "best_state" on this (the most recent) frame
+// towards earlier frames.  For each real frame visited it writes the state
+// index to the .first member and that state's pov_nccf to the .second member
+// of the matching lag_nccf entry, starting at the last entry of lag_nccf and
+// moving towards the front.  It returns as soon as it reaches a frame whose
+// recorded cur_best_state_ already equals the state being traced.
+void PitchFrameInfo::SetBestState(
+    int32 best_state,
+    std::vector > &lag_nccf) {
+
+  // This function would naturally be recursive, but we have coded this to avoid
+  // recursion, which would otherwise eat up the stack.  Think of it as a static
+  // member function, except we do use "this" right at the beginning.
+
+  std::vector >::reverse_iterator iter = lag_nccf.rbegin();
+
+  PitchFrameInfo *this_info = this;  // it will change in the loop.
+  while (this_info != NULL) {
+    PitchFrameInfo *prev_info = this_info->prev_info_;
+    if (best_state == this_info->cur_best_state_)
+      return;  // no change
+    if (prev_info != NULL)  // don't write anything for frame -1.
+      iter->first = best_state;
+    // state_info_ is indexed relative to state_offset_.
+    size_t state_info_index = best_state - this_info->state_offset_;
+    KALDI_ASSERT(state_info_index < this_info->state_info_.size());
+    this_info->cur_best_state_ = best_state;
+    // From here on best_state refers to the previous frame.
+    best_state = this_info->state_info_[state_info_index].backpointer;
+    if (prev_info != NULL)  // don't write anything for frame -1.
+      iter->second = this_info->state_info_[state_info_index].pov_nccf;
+    this_info = prev_info;
+    if (this_info != NULL) ++iter;
+  }
+}
+
+// Starting from the lowest and the highest state of this (the most recent)
+// frame, follows both backpointer chains towards earlier frames until they
+// arrive at the same state, and returns the number of frames stepped back at
+// that point.  The result never exceeds max_latency, and is 0 if max_latency
+// is not positive.
+int32 PitchFrameInfo::ComputeLatency(int32 max_latency) {
+  if (max_latency <= 0) return 0;
+
+  int32 latency = 0;
+
+  // This function would naturally be recursive, but we have coded this to avoid
+  // recursion, which would otherwise eat up the stack.  Think of it as a static
+  // member function, except we do use "this" right at the beginning.
+  // This function is called only on the most recent PitchFrameInfo object.
+  int32 num_states = state_info_.size();
+  int32 min_living_state = 0, max_living_state = num_states - 1;
+  PitchFrameInfo *this_info = this;  // it will change in the loop.
+
+
+  for (; this_info != NULL && latency < max_latency;) {
+    int32 offset = this_info->state_offset_;
+    KALDI_ASSERT(min_living_state >= offset &&
+                 max_living_state - offset < this_info->state_info_.size());
+    min_living_state =
+        this_info->state_info_[min_living_state - offset].backpointer;
+    max_living_state =
+        this_info->state_info_[max_living_state - offset].backpointer;
+    if (min_living_state == max_living_state) {
+      return latency;
+    }
+    this_info = this_info->prev_info_;
+    if (this_info != NULL)  // avoid incrementing latency for frame -1,
+      latency++;            // as it's not a real frame.
+  }
+  return latency;
+}
+
+// Not implemented: calling this only raises an error through KALDI_ERR;
+// prev_frame is not used.
+void PitchFrameInfo::Cleanup(PitchFrameInfo *prev_frame) {
+  KALDI_ERR << "Cleanup not implemented.";
+}
+
+
+// struct NccfInfo is used to cache certain quantities that we need for online
+// operation, for the first "recompute_frame" frames of the file (e.g. 300);
+// after that many frames, or after the user calls InputFinished(), we redo the
+// initial backtraces, as we'll then have a better estimate of the average signal
+// energy.
+struct NccfInfo {
+
+  Vector nccf_pitch_resampled;  // resampled nccf_pitch
+  BaseFloat avg_norm_prod; // average value of e1 * e2.
+  BaseFloat mean_square_energy;  // mean_square energy we used when computing the
+                                 // original ballast term for
+                                 // "nccf_pitch_resampled".
+
+  // The constructor only stores the two scalars; nccf_pitch_resampled is left
+  // default-constructed.
+  NccfInfo(BaseFloat avg_norm_prod,
+           BaseFloat mean_square_energy):
+      avg_norm_prod(avg_norm_prod),
+      mean_square_energy(mean_square_energy) { }
+};
+
+
+
+// We could inherit from OnlineBaseFeature as we have the same interface,
+// but this would unnecessarily force a lot of our functions to be virtual.
+// Implementation class for online pitch extraction: it resamples the incoming
+// waveform, computes the NCCF per frame, runs the Viterbi search over the
+// log-spaced lags and exposes a two-dimensional feature per frame.
+// NOTE(review): template arguments (of Vector, VectorBase, SubVector,
+// std::vector and static_cast) are missing throughout this part of the file;
+// the angle-bracket parts appear to have been lost -- TODO restore them.
+class OnlinePitchFeatureImpl {
+ public:
+  explicit OnlinePitchFeatureImpl(const PitchExtractionOptions &opts);
+
+  // Each output frame has two elements, see GetFrame().
+  int32 Dim() const { return 2; }
+
+  BaseFloat FrameShiftInSeconds() const;
+
+  // Number of frames that can currently be read with GetFrame().
+  int32 NumFramesReady() const;
+
+  // True only if InputFinished() has been called and "frame" is the final
+  // frame.
+  bool IsLastFrame(int32 frame) const;
+
+  // Writes the NCCF (as kept for the POV computation) to (*feat)(0) and the
+  // reciprocal of the chosen lag to (*feat)(1); feat must have dimension 2.
+  void GetFrame(int32 frame, VectorBase *feat);
+
+  void AcceptWaveform(BaseFloat sampling_rate,
+                      const VectorBase &waveform);
+
+  void InputFinished();
+
+  ~OnlinePitchFeatureImpl();
+
+
+  // Copy-constructor, can be used to obtain a new copy of this object,
+  // any state from this utterance.
+  OnlinePitchFeatureImpl(const OnlinePitchFeatureImpl &other);
+
+ private:
+
+  /// This function works out from the signal how many frames are currently
+  /// available to process (this is called from inside AcceptWaveform()).
+  /// Note: the number of frames differs slightly from the number the
+  /// old pitch code gave.
+  /// Note: the number this returns depends on whether input_finished_ == true;
+  /// if it is, it will "force out" a final frame or two.
+  int32 NumFramesAvailable(int64 num_downsampled_samples, bool snip_edges) const;
+
+  /// This function extracts from the signal the samples numbered from
+  /// "sample_index" (numbered in the full downsampled signal, not just this
+  /// part), and of length equal to window->Dim().  It uses the data members
+  /// downsampled_samples_processed_ and downsampled_signal_remainder_, as well
+  /// as the more recent part of the downsampled wave "downsampled_wave_part"
+  /// which is provided.
+  ///
+  /// @param downsampled_wave_part  One chunk of the downsampled wave,
+  ///                      starting from sample-index
+  ///                      downsampled_samples_processed_.
+  /// @param sample_index  The desired starting sample index (measured from
+  ///                      the start of the whole signal, not just this part).
+  ///                      NOTE(review): this declaration names the parameter
+  ///                      "frame_index", whereas the definition names it
+  ///                      "sample_index" and treats it as a sample index.
+  /// @param window  The part of the signal is output to here.
+  void ExtractFrame(const VectorBase &downsampled_wave_part,
+                    int64 frame_index,
+                    VectorBase *window);
+
+
+  /// This function is called after we reach frame "recompute_frame", or when
+  /// InputFinished() is called, whichever comes sooner.  It recomputes the
+  /// backtraces for frames zero through recompute_frame, if needed because the
+  /// average energy of the signal has changed, affecting the nccf ballast term.
+  /// It works out the average signal energy from
+  /// downsampled_samples_processed_, signal_sum_ and signal_sumsq_ (which, if
+  /// you see the calling code, might include more frames than just
+  /// "recompute_frame", it might include up to the end of the current chunk).
+  void RecomputeBacktraces();
+
+
+  /// This function updates downsampled_signal_remainder_,
+  /// downsampled_samples_processed_, signal_sum_ and signal_sumsq_; it's called
+  /// from AcceptWaveform().
+  void UpdateRemainder(const VectorBase &downsampled_wave_part);
+
+
+  // The following variables don't change throughout the lifetime
+  // of this object.
+  PitchExtractionOptions opts_;
+
+  // the first lag of the downsampled signal at which we measure NCCF
+  int32 nccf_first_lag_;
+  // the last lag of the downsampled signal at which we measure NCCF
+  int32 nccf_last_lag_;
+
+  // The log-spaced lags at which we will resample the NCCF
+  Vector lags_;
+
+  // This object is used to resample from evenly spaced to log-evenly-spaced
+  // nccf values.  It's a pointer for convenience of initialization, so we don't
+  // have to use the initializer from the constructor.
+  ArbitraryResample *nccf_resampler_;
+
+  // The following objects may change during the lifetime of this object.
+
+  // This object is used to resample the signal.
+  LinearResample *signal_resampler_;
+
+  // frame_info_ is indexed by [frame-index + 1].  frame_info_[0] is an object
+  // that corresponds to frame -1, which is not a real frame.
+  std::vector frame_info_;
+
+
+  // nccf_info_ is indexed by frame-index, from frame 0 to at most
+  // opts_.recompute_frame - 1.  It contains some information we'll
+  // need to recompute the tracebacks after getting a better estimate
+  // of the average energy of the signal.
+  std::vector nccf_info_;
+
+  // Current number of frames which we can't output because Viterbi has not
+  // converged for them, or opts_.max_frames_latency if we have reached that
+  // limit.
+  int32 frames_latency_;
+
+  // The forward-cost at the current frame (the last frame in frame_info_);
+  // this has the same dimension as lags_.  We normalize each time so
+  // the lowest cost is zero, for numerical accuracy and so we can use float.
+  Vector forward_cost_;
+
+  // stores the constant part of forward_cost_.
+  double forward_cost_remainder_;
+
+  // The resampled-lag index and the NCCF (as computed for POV, without ballast
+  // term) for each frame, as determined by Viterbi traceback from the best
+  // final state.
+  std::vector > lag_nccf_;
+
+  // Set to true by InputFinished().
+  bool input_finished_;
+
+  /// sum-squared of previously processed parts of signal; used to get NCCF
+  /// ballast term.  Denominator is downsampled_samples_processed_.
+  double signal_sumsq_;
+
+  /// sum of previously processed parts of signal; used to do mean-subtraction
+  /// when getting sum-squared, along with signal_sumsq_.
+  double signal_sum_;
+
+  /// downsampled_samples_processed is the number of samples (after
+  /// downsampling) that we got in previous calls to AcceptWaveform().
+  int64 downsampled_samples_processed_;
+  /// This is a small remainder of the previous downsampled signal;
+  /// it's used by ExtractFrame for frames near the boundary of two
+  /// waveforms supplied to AcceptWaveform().
+  Vector downsampled_signal_remainder_;
+};
+
+
+// Sets up the two resamplers, the range of integer lags at which the NCCF is
+// measured, the log-spaced lags at which it is resampled, and the placeholder
+// entry for frame -1.  Both resamplers and the frame -1 entry are allocated
+// with "new" here.
+OnlinePitchFeatureImpl::OnlinePitchFeatureImpl(
+    const PitchExtractionOptions &opts):
+    opts_(opts), forward_cost_remainder_(0.0), input_finished_(false),
+    signal_sumsq_(0.0), signal_sum_(0.0), downsampled_samples_processed_(0) {
+  signal_resampler_ = new LinearResample(opts.samp_freq, opts.resample_freq,
+                                         opts.lowpass_cutoff,
+                                         opts.lowpass_filter_width);
+
+  // Widen the lag range [1/max_f0, 1/min_f0] on both sides by
+  // upsample_filter_width / (2 * resample_freq), then convert it to integer
+  // lags of the downsampled signal, rounding inwards.
+  double outer_min_lag = 1.0 / opts.max_f0 -
+      (opts.upsample_filter_width/(2.0 * opts.resample_freq));
+  double outer_max_lag = 1.0 / opts.min_f0 +
+      (opts.upsample_filter_width/(2.0 * opts.resample_freq));
+  nccf_first_lag_ = ceil(opts.resample_freq * outer_min_lag);
+  nccf_last_lag_ = floor(opts.resample_freq * outer_max_lag);
+
+  frames_latency_ = 0;  // will be set in AcceptWaveform()
+
+  // Choose the lags at which we resample the NCCF.
+  SelectLags(opts, &lags_);
+
+  // upsample_cutoff is the filter cutoff for upsampling the NCCF, which is the
+  // Nyquist of the resampling frequency.  The NCCF is (almost completely)
+  // bandlimited to around "lowpass_cutoff" (1000 by default), and when the
+  // spectrum of this bandlimited signal is convolved with the spectrum of an
+  // impulse train with frequency "resample_freq", which are separated by 4kHz,
+  // we get energy at -5000,-3000, -1000...1000, 3000..5000, etc.  Filtering at
+  // half the Nyquist (2000 by default) is sufficient to get only the first
+  // repetition.
+  BaseFloat upsample_cutoff = opts.resample_freq * 0.5;
+
+
+  Vector lags_offset(lags_);
+  // lags_offset equals lags_ (which are the log-spaced lag values we want to
+  // measure the NCCF at) with nccf_first_lag_ / opts.resample_freq subtracted
+  // from each element, so we can treat the measured NCCF values as starting
+  // from sample zero in a signal that starts at the point nccf_first_lag_ /
+  // opts.resample_freq.  This is necessary because the ArbitraryResample code
+  // assumes that the input signal starts from sample zero.
+  lags_offset.Add(-nccf_first_lag_ / opts.resample_freq);
+
+  int32 num_measured_lags = nccf_last_lag_ + 1 - nccf_first_lag_;
+
+  nccf_resampler_ = new ArbitraryResample(num_measured_lags, opts.resample_freq,
+                                          upsample_cutoff, lags_offset,
+                                          opts.upsample_filter_width);
+
+  // add a PitchInfo object for frame -1 (not a real frame).
+  frame_info_.push_back(new PitchFrameInfo(lags_.Dim()));
+  // zeroes forward_cost_; this is what we want for the fake frame -1.
+  forward_cost_.Resize(lags_.Dim());
+}
+
+
+// Returns how many frames can be computed from num_downsampled_samples
+// samples.  While input is still arriving a frame needs NccfWindowSize() +
+// nccf_last_lag_ samples; once input_finished_ is true only NccfWindowSize()
+// is required.  Returns 0 if there are fewer samples than that.
+int32 OnlinePitchFeatureImpl::NumFramesAvailable(
+    int64 num_downsampled_samples, bool snip_edges) const {
+  int32 frame_shift = opts_.NccfWindowShift(),
+      frame_length = opts_.NccfWindowSize();
+  // Use the "full frame length" to compute the number
+  // of frames only if the input is not finished.
+  if (!input_finished_)
+    frame_length += nccf_last_lag_;
+  if (num_downsampled_samples < frame_length) {
+    return 0;
+  } else {
+    if (!snip_edges) {
+      // Without edge snipping the count is a rounded ratio (the + 0.5f
+      // rounds to nearest before the conversion to integer).
+      if (input_finished_) {
+        return static_cast(num_downsampled_samples * 1.0f /
+                           frame_shift + 0.5f);
+      } else {
+        return static_cast((num_downsampled_samples - frame_length / 2) *
+                           1.0f / frame_shift + 0.5f);
+      }
+    } else {
+      // With edge snipping only frames that fit entirely inside the signal
+      // are counted.
+      return static_cast((num_downsampled_samples - frame_length) /
+                         frame_shift + 1);
+    }
+  }
+}
+
+// Accumulates the sum and sum-of-squares of the new chunk, then keeps in
+// downsampled_signal_remainder_ the samples from the start of the next
+// unprocessed frame up to the end of the signal received so far, and advances
+// downsampled_samples_processed_ by the length of the chunk.
+void OnlinePitchFeatureImpl::UpdateRemainder(
+    const VectorBase &downsampled_wave_part) {
+  // frame_info_ has an extra element at frame-1, so subtract
+  // one from the length.
+  int64 num_frames = static_cast(frame_info_.size()) - 1,
+      next_frame = num_frames,
+      frame_shift = opts_.NccfWindowShift(),
+      next_frame_sample = frame_shift * next_frame;
+
+  signal_sumsq_ += VecVec(downsampled_wave_part, downsampled_wave_part);
+  signal_sum_ += downsampled_wave_part.Sum();
+
+  // next_frame_sample is the first sample index we'll need for the
+  // next frame.
+  int64 next_downsampled_samples_processed =
+      downsampled_samples_processed_ + downsampled_wave_part.Dim();
+
+  if (next_frame_sample > next_downsampled_samples_processed) {
+    // this could only happen in the weird situation that the full frame length
+    // is less than the frame shift.
+    int32 full_frame_length = opts_.NccfWindowSize() + nccf_last_lag_;
+    KALDI_ASSERT(full_frame_length < frame_shift && "Code error");
+    downsampled_signal_remainder_.Resize(0);
+  } else {
+    Vector new_remainder(next_downsampled_samples_processed -
+                         next_frame_sample);
+    // note: next_frame_sample is the index into the entire signal, of
+    // new_remainder(0).
+    // i is the absolute index of the signal.
+    for (int64 i = next_frame_sample;
+         i < next_downsampled_samples_processed; i++) {
+      if (i >= downsampled_samples_processed_) {  // in current signal.
+        new_remainder(i - next_frame_sample) =
+            downsampled_wave_part(i - downsampled_samples_processed_);
+      } else {  // in old remainder; only reach here if waveform supplied is
+        new_remainder(i - next_frame_sample) =  //  tiny.
+            downsampled_signal_remainder_(i - downsampled_samples_processed_ +
+                                          downsampled_signal_remainder_.Dim());
+      }
+    }
+    downsampled_signal_remainder_.Swap(&new_remainder);
+  }
+  downsampled_samples_processed_ = next_downsampled_samples_processed;
+}
+
+// Copies window->Dim() samples starting at absolute index sample_index into
+// *window, taking them from downsampled_signal_remainder_ and/or
+// downsampled_wave_part as needed, padding with zeros where the requested
+// range lies outside the signal, and finally applying pre-emphasis if
+// opts_.preemph_coeff is nonzero.
+void OnlinePitchFeatureImpl::ExtractFrame(
+    const VectorBase &downsampled_wave_part,
+    int64 sample_index,
+    VectorBase *window) {
+  int32 full_frame_length = window->Dim();
+  // Position of the first requested sample relative to the start of
+  // downsampled_wave_part; negative if the frame starts in the remainder.
+  int32 offset = static_cast(sample_index -
+                             downsampled_samples_processed_);
+
+  // Treat edge cases first
+  if (sample_index < 0) {
+    // Part of the frame is before the beginning of the signal. This
+    // should only happen if opts_.snip_edges == false, when we are
+    // processing the first few frames of signal.  In this case
+    // we pad with zeros.
+    KALDI_ASSERT(opts_.snip_edges == false);
+    int32 sub_frame_length = sample_index + full_frame_length;
+    int32 sub_frame_index = full_frame_length - sub_frame_length;
+    KALDI_ASSERT(sub_frame_length > 0 && sub_frame_index > 0);
+    window->SetZero();
+    // Fill only the tail of the window, by a recursive call that starts at
+    // sample 0; the head stays zero.
+    SubVector sub_window(*window, sub_frame_index, sub_frame_length);
+    ExtractFrame(downsampled_wave_part, 0, &sub_window);
+    return;
+  }
+
+  if (offset + full_frame_length > downsampled_wave_part.Dim()) {
+    // Requested frame is past end of the signal.  This should only happen if
+    // input_finished_ == true, when we're flushing out the last couple of
+    // frames of signal.  In this case we pad with zeros.
+    KALDI_ASSERT(input_finished_);
+    int32 sub_frame_length = downsampled_wave_part.Dim() - offset;
+    KALDI_ASSERT(sub_frame_length > 0);
+    window->SetZero();
+    // Fill only the head of the window, by a recursive call on the part that
+    // exists; the tail stays zero.
+    SubVector sub_window(*window, 0, sub_frame_length);
+    ExtractFrame(downsampled_wave_part, sample_index, &sub_window);
+    return;
+  }
+
+  // "offset" is the offset of the start of the frame, into this
+  // signal.
+  if (offset >= 0) {
+    // frame is full inside the new part of the signal.
+    window->CopyFromVec(downsampled_wave_part.Range(offset, full_frame_length));
+  } else {
+    // frame is partly in the remainder and partly in the new part.
+    int32 remainder_offset = downsampled_signal_remainder_.Dim() + offset;
+    KALDI_ASSERT(remainder_offset >= 0);  // or we didn't keep enough remainder.
+    KALDI_ASSERT(offset + full_frame_length > 0);  // or we should have
+                                                   // processed this frame last
+                                                   // time.
+
+    // The first old_length samples come from the end of the remainder, the
+    // following new_length samples from the start of the new chunk.
+    int32 old_length = -offset, new_length = offset + full_frame_length;
+    window->Range(0, old_length).CopyFromVec(
+        downsampled_signal_remainder_.Range(remainder_offset, old_length));
+    window->Range(old_length, new_length).CopyFromVec(
+        downsampled_wave_part.Range(0, new_length));
+  }
+  if (opts_.preemph_coeff != 0.0) {
+    // Pre-emphasis, done in place from the last sample to the first so that
+    // each step still reads the unmodified previous sample.
+    BaseFloat preemph_coeff = opts_.preemph_coeff;
+    for (int32 i = window->Dim() - 1; i > 0; i--)
+      (*window)(i) -= preemph_coeff * (*window)(i-1);
+    (*window)(0) *= (1.0 - preemph_coeff);
+  }
+}
+
+// "frame" must be less than NumFramesReady() (asserted).
+bool OnlinePitchFeatureImpl::IsLastFrame(int32 frame) const {
+  int32 T = NumFramesReady();
+  KALDI_ASSERT(frame < T);
+  return (input_finished_ && frame + 1 == T);
+}
+
+// Converts opts_.frame_shift_ms from milliseconds to seconds.
+BaseFloat OnlinePitchFeatureImpl::FrameShiftInSeconds() const {
+  return opts_.frame_shift_ms / 1000.0f;
+}
+
+// The frames for which a traceback result exists, minus the most recent
+// frames_latency_ frames, which are held back.
+int32 OnlinePitchFeatureImpl::NumFramesReady() const {
+  int32 num_frames = lag_nccf_.size(),
+      latency = frames_latency_;
+  KALDI_ASSERT(latency <= num_frames);
+  return num_frames - latency;
+}
+
+
+// Output layout: (*feat)(0) is the NCCF stored for the frame; (*feat)(1) is
+// 1 / lag for the lag index chosen by the traceback.  Since the lags range
+// from 1/max_f0 to 1/min_f0 (see SelectLags), the second value is a pitch
+// frequency in the same unit as min_f0 and max_f0.
+void OnlinePitchFeatureImpl::GetFrame(int32 frame,
+                                      VectorBase *feat) {
+  KALDI_ASSERT(frame < NumFramesReady() && feat->Dim() == 2);
+  (*feat)(0) = lag_nccf_[frame].second;
+  (*feat)(1) = 1.0 / lags_(lag_nccf_[frame].first);
+}
+
+// Marks the input as finished, flushes the last frames, possibly redoes the
+// early backtraces, and releases all frames that were held back for latency.
+void OnlinePitchFeatureImpl::InputFinished() {
+  input_finished_ = true;
+  // Process an empty waveform; this has an effect because
+  // after setting input_finished_ to true, NumFramesAvailable()
+  // will return a slightly larger number.
+  AcceptWaveform(opts_.samp_freq, Vector());
+  // frame_info_ holds one extra entry for frame -1, hence the "- 1".
+  int32 num_frames = static_cast(frame_info_.size() - 1);
+  if (num_frames < opts_.recompute_frame && !opts_.nccf_ballast_online)
+    RecomputeBacktraces();
+  frames_latency_ = 0;
+  // NOTE(review): if no frame was produced, num_frames is 0 and the average
+  // below divides by zero.  The dividend is a double, so this is a
+  // floating-point division and only affects the logged value -- confirm
+  // this is acceptable.
+  KALDI_VLOG(3) << "Pitch-tracking Viterbi cost is "
+                << (forward_cost_remainder_ / num_frames)
+                << " per frame, over " << num_frames << " frames.";
+}
+
+// see comment with declaration.  This is only relevant for online
+// operation (it gets called for non-online mode, but is a no-op).
+void OnlinePitchFeatureImpl::RecomputeBacktraces() { + KALDI_ASSERT(!opts_.nccf_ballast_online); + int32 num_frames = static_cast(frame_info_.size()) - 1; + + // The assertion reflects how we believe this function will be called. + KALDI_ASSERT(num_frames <= opts_.recompute_frame); + KALDI_ASSERT(nccf_info_.size() == static_cast(num_frames)); + if (num_frames == 0) + return; + double num_samp = downsampled_samples_processed_, sum = signal_sum_, + sumsq = signal_sumsq_, mean = sum / num_samp; + BaseFloat mean_square = sumsq / num_samp - mean * mean; + + bool must_recompute = false; + BaseFloat threshold = 0.01; + for (int32 frame = 0; frame < num_frames; frame++) + if (!ApproxEqual(nccf_info_[frame]->mean_square_energy, + mean_square, threshold)) + must_recompute = true; + + if (!must_recompute) { + // Nothing to do. We'll reach here, for instance, if everything was in one + // chunk and opts_.nccf_ballast_online == false. This is the case for + // offline processing. + for (size_t i = 0; i < nccf_info_.size(); i++) + delete nccf_info_[i]; + nccf_info_.clear(); + return; + } + + int32 num_states = forward_cost_.Dim(), + basic_frame_length = opts_.NccfWindowSize(); + + BaseFloat new_nccf_ballast = pow(mean_square * basic_frame_length, 2) * + opts_.nccf_ballast; + + double forward_cost_remainder = 0.0; + Vector forward_cost(num_states), // start off at zero. 
+ next_forward_cost(forward_cost); + std::vector > index_info; + + for (int32 frame = 0; frame < num_frames; frame++) { + NccfInfo &nccf_info = *nccf_info_[frame]; + BaseFloat old_mean_square = nccf_info_[frame]->mean_square_energy, + avg_norm_prod = nccf_info_[frame]->avg_norm_prod, + old_nccf_ballast = pow(old_mean_square * basic_frame_length, 2) * + opts_.nccf_ballast, + nccf_scale = pow((old_nccf_ballast + avg_norm_prod) / + (new_nccf_ballast + avg_norm_prod), + static_cast(0.5)); + // The "nccf_scale" is an estimate of the scaling factor by which the NCCF + // would change on this frame, on average, by changing the ballast term from + // "old_nccf_ballast" to "new_nccf_ballast". It's not exact because the + // "avg_norm_prod" is just an average of the product e1 * e2 of frame + // energies of the (frame, shifted-frame), but these won't change that much + // within a frame, and even if they do, the inaccuracy of the scaled NCCF + // will still be very small if the ballast term didn't change much, or if + // it's much larger or smaller than e1*e2. By doing it as a simple scaling, + // we save the overhead of the NCCF resampling, which is a considerable part + // of the whole computation. 
+ nccf_info.nccf_pitch_resampled.Scale(nccf_scale); + + frame_info_[frame + 1]->ComputeBacktraces( + opts_, nccf_info.nccf_pitch_resampled, lags_, + forward_cost, &index_info, &next_forward_cost); + + forward_cost.Swap(&next_forward_cost); + BaseFloat remainder = forward_cost.Min(); + forward_cost_remainder += remainder; + forward_cost.Add(-remainder); + } + KALDI_VLOG(3) << "Forward-cost per frame changed from " + << (forward_cost_remainder_ / num_frames) << " to " + << (forward_cost_remainder / num_frames); + + forward_cost_remainder_ = forward_cost_remainder; + forward_cost_.Swap(&forward_cost); + + int32 best_final_state; + forward_cost_.Min(&best_final_state); + + if (lag_nccf_.size() != static_cast(num_frames)) + lag_nccf_.resize(num_frames); + + frame_info_.back()->SetBestState(best_final_state, lag_nccf_); + frames_latency_ = + frame_info_.back()->ComputeLatency(opts_.max_frames_latency); + for (size_t i = 0; i < nccf_info_.size(); i++) + delete nccf_info_[i]; + nccf_info_.clear(); +} + +OnlinePitchFeatureImpl::~OnlinePitchFeatureImpl() { + delete nccf_resampler_; + delete signal_resampler_; + for (size_t i = 0; i < frame_info_.size(); i++) + delete frame_info_[i]; + for (size_t i = 0; i < nccf_info_.size(); i++) + delete nccf_info_[i]; +} + +void OnlinePitchFeatureImpl::AcceptWaveform( + BaseFloat sampling_rate, + const VectorBase &wave) { + // flush out the last few samples of input waveform only if input_finished_ == + // true. + const bool flush = input_finished_; + + Vector downsampled_wave; + signal_resampler_->Resample(wave, flush, &downsampled_wave); + + // these variables will be used to compute the root-mean-square value of the + // signal for the ballast term. 
+ double cur_sumsq = signal_sumsq_, cur_sum = signal_sum_; + int64 cur_num_samp = downsampled_samples_processed_, + prev_frame_end_sample = 0; + if (!opts_.nccf_ballast_online) { + cur_sumsq += VecVec(downsampled_wave, downsampled_wave); + cur_sum += downsampled_wave.Sum(); + cur_num_samp += downsampled_wave.Dim(); + } + + // end_frame is the total number of frames we can now process, including + // previously processed ones. + int32 end_frame = NumFramesAvailable( + downsampled_samples_processed_ + downsampled_wave.Dim(), opts_.snip_edges); + // "start_frame" is the first frame-index we process + int32 start_frame = frame_info_.size() - 1, + num_new_frames = end_frame - start_frame; + + if (num_new_frames == 0) { + UpdateRemainder(downsampled_wave); + return; + // continuing to the rest of the code would generate + // an error when sizing matrices with zero rows, and + // anyway is a waste of time. + } + + int32 num_measured_lags = nccf_last_lag_ + 1 - nccf_first_lag_, + num_resampled_lags = lags_.Dim(), + frame_shift = opts_.NccfWindowShift(), + basic_frame_length = opts_.NccfWindowSize(), + full_frame_length = basic_frame_length + nccf_last_lag_; + + Vector window(full_frame_length), + inner_prod(num_measured_lags), + norm_prod(num_measured_lags); + Matrix nccf_pitch(num_new_frames, num_measured_lags), + nccf_pov(num_new_frames, num_measured_lags); + + Vector cur_forward_cost(num_resampled_lags); + + + // Because the resampling of the NCCF is more efficient when grouped together, + // we first compute the NCCF for all frames, then resample as a matrix, then + // do the Viterbi [that happens inside the constructor of PitchFrameInfo]. + + for (int32 frame = start_frame; frame < end_frame; frame++) { + // start_sample is index into the whole wave, not just this part. 
+ int64 start_sample; + if (opts_.snip_edges) { + // Usual case: offset starts at 0 + start_sample = static_cast(frame) * frame_shift; + } else { + // When we are not snipping the edges, the first offsets may be + // negative. In this case we will pad with zeros, it should not impact + // the pitch tracker. + start_sample = + static_cast((frame + 0.5) * frame_shift) - full_frame_length / 2; + } + ExtractFrame(downsampled_wave, start_sample, &window); + if (opts_.nccf_ballast_online) { + // use only up to end of current frame to compute root-mean-square value. + // end_sample will be the sample-index into "downsampled_wave", so + // not really comparable to start_sample. + int64 end_sample = start_sample + full_frame_length - + downsampled_samples_processed_; + KALDI_ASSERT(end_sample > 0); // or should have processed this frame last + // time. Note: end_sample is one past last + // sample. + if (end_sample > downsampled_wave.Dim()) { + KALDI_ASSERT(input_finished_); + end_sample = downsampled_wave.Dim(); + } + SubVector new_part(downsampled_wave, prev_frame_end_sample, + end_sample - prev_frame_end_sample); + cur_num_samp += new_part.Dim(); + cur_sumsq += VecVec(new_part, new_part); + cur_sum += new_part.Sum(); + prev_frame_end_sample = end_sample; + } + double mean_square = cur_sumsq / cur_num_samp - + pow(cur_sum / cur_num_samp, 2.0); + + ComputeCorrelation(window, nccf_first_lag_, nccf_last_lag_, + basic_frame_length, &inner_prod, &norm_prod); + double nccf_ballast_pov = 0.0, + nccf_ballast_pitch = pow(mean_square * basic_frame_length, 2) * + opts_.nccf_ballast, + avg_norm_prod = norm_prod.Sum() / norm_prod.Dim(); + SubVector nccf_pitch_row(nccf_pitch, frame - start_frame); + ComputeNccf(inner_prod, norm_prod, nccf_ballast_pitch, + &nccf_pitch_row); + SubVector nccf_pov_row(nccf_pov, frame - start_frame); + ComputeNccf(inner_prod, norm_prod, nccf_ballast_pov, + &nccf_pov_row); + if (frame < opts_.recompute_frame) + nccf_info_.push_back(new 
NccfInfo(avg_norm_prod, mean_square)); + } + + Matrix nccf_pitch_resampled(num_new_frames, num_resampled_lags); + nccf_resampler_->Resample(nccf_pitch, &nccf_pitch_resampled); + nccf_pitch.Resize(0, 0); // no longer needed. + Matrix nccf_pov_resampled(num_new_frames, num_resampled_lags); + nccf_resampler_->Resample(nccf_pov, &nccf_pov_resampled); + nccf_pov.Resize(0, 0); // no longer needed. + + // We've finished dealing with the waveform so we can call UpdateRemainder + // now; we need to call it before we possibly call RecomputeBacktraces() + // below, which is why we don't do it at the very end. + UpdateRemainder(downsampled_wave); + + std::vector > index_info; + + for (int32 frame = start_frame; frame < end_frame; frame++) { + int32 frame_idx = frame - start_frame; + PitchFrameInfo *prev_info = frame_info_.back(), + *cur_info = new PitchFrameInfo(prev_info); + cur_info->SetNccfPov(nccf_pov_resampled.Row(frame_idx)); + cur_info->ComputeBacktraces(opts_, nccf_pitch_resampled.Row(frame_idx), + lags_, forward_cost_, &index_info, + &cur_forward_cost); + forward_cost_.Swap(&cur_forward_cost); + // Renormalize forward_cost so smallest element is zero. + BaseFloat remainder = forward_cost_.Min(); + forward_cost_remainder_ += remainder; + forward_cost_.Add(-remainder); + frame_info_.push_back(cur_info); + if (frame < opts_.recompute_frame) + nccf_info_[frame]->nccf_pitch_resampled = + nccf_pitch_resampled.Row(frame_idx); + if (frame == opts_.recompute_frame - 1 && !opts_.nccf_ballast_online) + RecomputeBacktraces(); + } + + // Trace back the best-path. + int32 best_final_state; + forward_cost_.Min(&best_final_state); + lag_nccf_.resize(frame_info_.size() - 1); // will keep any existing data. 
+ frame_info_.back()->SetBestState(best_final_state, lag_nccf_); + frames_latency_ = + frame_info_.back()->ComputeLatency(opts_.max_frames_latency); + KALDI_VLOG(4) << "Latency is " << frames_latency_; +} + + + +// Some functions that forward from OnlinePitchFeature to +// OnlinePitchFeatureImpl. +int32 OnlinePitchFeature::NumFramesReady() const { + return impl_->NumFramesReady(); +} + +OnlinePitchFeature::OnlinePitchFeature(const PitchExtractionOptions &opts) + :impl_(new OnlinePitchFeatureImpl(opts)) { } + +bool OnlinePitchFeature::IsLastFrame(int32 frame) const { + return impl_->IsLastFrame(frame); +} + +BaseFloat OnlinePitchFeature::FrameShiftInSeconds() const { + return impl_->FrameShiftInSeconds(); +} + +void OnlinePitchFeature::GetFrame(int32 frame, VectorBase *feat) { + impl_->GetFrame(frame, feat); +} + +void OnlinePitchFeature::AcceptWaveform( + BaseFloat sampling_rate, + const VectorBase &waveform) { + impl_->AcceptWaveform(sampling_rate, waveform); +} + +void OnlinePitchFeature::InputFinished() { + impl_->InputFinished(); +} + +OnlinePitchFeature::~OnlinePitchFeature() { + delete impl_; +} + + +/** + This function is called from ComputeKaldiPitch when the user + specifies opts.simulate_first_pass_online == true. It gives + the "first-pass" version of the features, which you would get + on the first decoding pass in an online setting. These may + differ slightly from the final features due to both the + way the Viterbi traceback works (this is affected by + opts.max_frames_latency), and the online way we compute + the average signal energy. 
+*/ +void ComputeKaldiPitchFirstPass( + const PitchExtractionOptions &opts, + const VectorBase &wave, + Matrix *output) { + + int32 cur_rows = 100; + Matrix feats(cur_rows, 2); + + OnlinePitchFeature pitch_extractor(opts); + KALDI_ASSERT(opts.frames_per_chunk > 0 && + "--simulate-first-pass-online option does not make sense " + "unless you specify --frames-per-chunk"); + + int32 cur_offset = 0, cur_frame = 0, samp_per_chunk = + opts.frames_per_chunk * opts.samp_freq * opts.frame_shift_ms / 1000.0f; + + while (cur_offset < wave.Dim()) { + int32 num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); + SubVector wave_chunk(wave, cur_offset, num_samp); + pitch_extractor.AcceptWaveform(opts.samp_freq, wave_chunk); + cur_offset += num_samp; + if (cur_offset == wave.Dim()) + pitch_extractor.InputFinished(); + // Get each frame as soon as it is ready. + for (; cur_frame < pitch_extractor.NumFramesReady(); cur_frame++) { + if (cur_frame >= cur_rows) { + cur_rows *= 2; + feats.Resize(cur_rows, 2, kCopyData); + } + SubVector row(feats, cur_frame); + pitch_extractor.GetFrame(cur_frame, &row); + } + } + if (cur_frame == 0) { + KALDI_WARN << "No features output since wave file too short"; + output->Resize(0, 0); + } else { + *output = feats.RowRange(0, cur_frame); + } +} + + + +void ComputeKaldiPitch(const PitchExtractionOptions &opts, + const VectorBase &wave, + Matrix *output) { + if (opts.simulate_first_pass_online) { + ComputeKaldiPitchFirstPass(opts, wave, output); + return; + } + OnlinePitchFeature pitch_extractor(opts); + + if (opts.frames_per_chunk == 0) { + pitch_extractor.AcceptWaveform(opts.samp_freq, wave); + } else { + // the user may set opts.frames_per_chunk for better compatibility with + // online operation. 
+ KALDI_ASSERT(opts.frames_per_chunk > 0); + int32 cur_offset = 0, samp_per_chunk = + opts.frames_per_chunk * opts.samp_freq * opts.frame_shift_ms / 1000.0f; + while (cur_offset < wave.Dim()) { + int32 num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); + SubVector wave_chunk(wave, cur_offset, num_samp); + pitch_extractor.AcceptWaveform(opts.samp_freq, wave_chunk); + cur_offset += num_samp; + } + } + pitch_extractor.InputFinished(); + int32 num_frames = pitch_extractor.NumFramesReady(); + if (num_frames == 0) { + KALDI_WARN << "No frames output in pitch extraction"; + output->Resize(0, 0); + return; + } + output->Resize(num_frames, 2); + for (int32 frame = 0; frame < num_frames; frame++) { + SubVector row(*output, frame); + pitch_extractor.GetFrame(frame, &row); + } +} + + +/* + This comment describes our investigation of how much latency the + online-processing algorithm introduces, i.e. how many frames you would + typically have to wait until the traceback converges, if you were to set the + --max-frames-latency to a very large value. + + This was done on a couple of files of language-id data. + + /home/dpovey/kaldi-online/src/featbin/compute-kaldi-pitch-feats --frames-per-chunk=10 --max-frames-latency=100 --verbose=4 --sample-frequency=8000 --resample-frequency=2600 "scp:head -n 2 data/train/wav.scp |" ark:/dev/null 2>&1 | grep Latency | wc + 4871 24355 443991 + /home/dpovey/kaldi-online/src/featbin/compute-kaldi-pitch-feats --frames-per-chunk=10 --max-frames-latency=100 --verbose=4 --sample-frequency=8000 --resample-frequency=2600 "scp:head -n 2 data/train/wav.scp |" ark:/dev/null 2>&1 | grep Latency | grep 100 | wc + 1534 7670 141128 + +# as above, but with 50 instead of 100 in the --max-frames-latency and grep statements. + 2070 10350 188370 +# as above, but with 10 instead of 50.
+ 4067 20335 370097 + + This says that out of 4871 selected frames [we measured the latency every 10 + frames, since --frames-per-chunk=10], in 1534 frames (31%), the latency was + >= 100 frames, i.e. >= 1 second. Including the other numbers, we can see + that + + 31% of frames had latency >= 1 second + 42% of frames had latency >= 0.5 second + 83% of frames had latency >= 0.1 second. + + This doesn't necessarily mean that we actually have a latency of >= 1 second 31% of + the time when using these features, since by using the --max-frames-latency option + (default: 30 frames), it will limit the latency to, say, 0.3 seconds, and trace back + from the best current pitch. Most of the time this will probably cause no change in + the pitch traceback since the best current pitch is probably the "right" point to + trace back from. And anyway, in the online-decoding, we will most likely rescore + the features at the end anyway, and the traceback gets recomputed, so there will + be no inaccuracy (assuming the first-pass lattice had everything we needed). + + Probably the greater source of inaccuracy due to the online algorithm is the + online energy-normalization, which affects the NCCF-ballast term, and which, + for reasons of efficiency, we don't attempt to "correct" in a later rescoring + pass. This will make the most difference in the first few frames of the file, + before the first voicing, where it will tend to produce more pitch movement + than the offline version of the algorithm. +*/ + + +// Function to do data accumulation for on-line usage +template +inline void AppendVector(const VectorBase &src, Vector *dst) { + if (src.Dim() == 0) return; + dst->Resize(dst->Dim() + src.Dim(), kCopyData); + dst->Range(dst->Dim() - src.Dim(), src.Dim()).CopyFromVec(src); +} + +/** + Note on the implementation of OnlineProcessPitch: the + OnlineFeatureInterface allows random access to features (i.e. not necessarily + sequential order), so we need to support that. 
But we don't need to support + it very efficiently, and our implementation is most efficient if frames are + accessed in sequential order. + + Also note: we have to be a bit careful in this implementation because + the input features may change. That is: if we call + src_->GetFrame(t, &vec) from GetFrame(), we can't guarantee that a later + call to src_->GetFrame(t, &vec) from another GetFrame() will return the + same value. In fact, while designing this class we used some knowledge + of how the OnlinePitchFeature class works to minimize the amount of + re-querying we had to do. +*/ +OnlineProcessPitch::OnlineProcessPitch( + const ProcessPitchOptions &opts, + OnlineFeatureInterface *src): + opts_(opts), src_(src), + dim_ ((opts.add_pov_feature ? 1 : 0) + + (opts.add_normalized_log_pitch ? 1 : 0) + + (opts.add_delta_pitch ? 1 : 0) + + (opts.add_raw_log_pitch ? 1 : 0)) { + KALDI_ASSERT(dim_ > 0 && + " At least one of the pitch features should be chosen. " + "Check your post-process-pitch options."); + KALDI_ASSERT(src->Dim() == kRawFeatureDim && + "Input feature must be pitch feature (should have dimension 2)"); +} + + +void OnlineProcessPitch::GetFrame(int32 frame, + VectorBase *feat) { + int32 frame_delayed = frame < opts_.delay ? 
0 : frame - opts_.delay; + KALDI_ASSERT(feat->Dim() == dim_ && + frame_delayed < NumFramesReady()); + int32 index = 0; + if (opts_.add_pov_feature) + (*feat)(index++) = GetPovFeature(frame_delayed); + if (opts_.add_normalized_log_pitch) + (*feat)(index++) = GetNormalizedLogPitchFeature(frame_delayed); + if (opts_.add_delta_pitch) + (*feat)(index++) = GetDeltaPitchFeature(frame_delayed); + if (opts_.add_raw_log_pitch) + (*feat)(index++) = GetRawLogPitchFeature(frame_delayed); + KALDI_ASSERT(index == dim_); +} + +BaseFloat OnlineProcessPitch::GetPovFeature(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); // (NCCF, pitch) from pitch extractor + BaseFloat nccf = tmp(0); + return opts_.pov_scale * NccfToPovFeature(nccf) + + opts_.pov_offset; +} + +BaseFloat OnlineProcessPitch::GetDeltaPitchFeature(int32 frame) { + // Rather than computing the delta pitch directly in code here, + // which might seem easier, we accumulate a small window of features + // and call ComputeDeltas. This might seem like overkill; the reason + // we do it this way is to ensure that the end effects (at file + // beginning and end) are handled in a consistent way. + int32 context = opts_.delta_window; + int32 start_frame = std::max(0, frame - context), + end_frame = std::min(frame + context + 1, src_->NumFramesReady()), + frames_in_window = end_frame - start_frame; + Matrix feats(frames_in_window, 1), + delta_feats; + + for (int32 f = start_frame; f < end_frame; f++) + feats(f - start_frame, 0) = GetRawLogPitchFeature(f); + + DeltaFeaturesOptions delta_opts; + delta_opts.order = 1; + delta_opts.window = opts_.delta_window; + ComputeDeltas(delta_opts, feats, &delta_feats); + while (delta_feature_noise_.size() <= static_cast(frame)) { + delta_feature_noise_.push_back(RandGauss() * + opts_.delta_pitch_noise_stddev); + } + // note: delta_feats will have two columns, second contains deltas. 
+ return (delta_feats(frame - start_frame, 1) + delta_feature_noise_[frame]) * + opts_.delta_pitch_scale; +} + +BaseFloat OnlineProcessPitch::GetRawLogPitchFeature(int32 frame) const { + Vector tmp(kRawFeatureDim); + src_->GetFrame(frame, &tmp); + BaseFloat pitch = tmp(1); + KALDI_ASSERT(pitch > 0); + return Log(pitch); +} + +BaseFloat OnlineProcessPitch::GetNormalizedLogPitchFeature(int32 frame) { + UpdateNormalizationStats(frame); + BaseFloat log_pitch = GetRawLogPitchFeature(frame), + avg_log_pitch = normalization_stats_[frame].sum_log_pitch_pov / + normalization_stats_[frame].sum_pov, + normalized_log_pitch = log_pitch - avg_log_pitch; + return normalized_log_pitch * opts_.pitch_scale; +} + + +// inline +void OnlineProcessPitch::GetNormalizationWindow(int32 t, + int32 src_frames_ready, + int32 *window_begin, + int32 *window_end) const { + int32 left_context = opts_.normalization_left_context; + int32 right_context = opts_.normalization_right_context; + *window_begin = std::max(0, t - left_context); + *window_end = std::min(t + right_context + 1, src_frames_ready); +} + + +// Makes sure the entry in normalization_stats_ for this frame is up to date; +// called from GetNormalizedLogPitchFeature. +// the cur_num_frames and input_finished variables are needed because the +// pitch features for a given frame may change as we see more data. +void OnlineProcessPitch::UpdateNormalizationStats(int32 frame) { + KALDI_ASSERT(frame >= 0); + if (normalization_stats_.size() <= frame) + normalization_stats_.resize(frame + 1); + int32 cur_num_frames = src_->NumFramesReady(); + bool input_finished = src_->IsLastFrame(cur_num_frames - 1); + + NormalizationStats &this_stats = normalization_stats_[frame]; + if (this_stats.cur_num_frames == cur_num_frames && + this_stats.input_finished == input_finished) { + // Stats are fully up-to-date. 
+ return; + } + int32 this_window_begin, this_window_end; + GetNormalizationWindow(frame, cur_num_frames, + &this_window_begin, &this_window_end); + + if (frame > 0) { + const NormalizationStats &prev_stats = normalization_stats_[frame - 1]; + if (prev_stats.cur_num_frames == cur_num_frames && + prev_stats.input_finished == input_finished) { + // we'll derive this_stats efficiently from prev_stats. + // Checking that cur_num_frames and input_finished have not changed + // ensures that the underlying features will not have changed. + this_stats = prev_stats; + int32 prev_window_begin, prev_window_end; + GetNormalizationWindow(frame - 1, cur_num_frames, + &prev_window_begin, &prev_window_end); + if (this_window_begin != prev_window_begin) { + KALDI_ASSERT(this_window_begin == prev_window_begin + 1); + Vector tmp(kRawFeatureDim); + src_->GetFrame(prev_window_begin, &tmp); + BaseFloat accurate_pov = NccfToPov(tmp(0)), + log_pitch = Log(tmp(1)); + this_stats.sum_pov -= accurate_pov; + this_stats.sum_log_pitch_pov -= accurate_pov * log_pitch; + } + if (this_window_end != prev_window_end) { + KALDI_ASSERT(this_window_end == prev_window_end + 1); + Vector tmp(kRawFeatureDim); + src_->GetFrame(prev_window_end, &tmp); + BaseFloat accurate_pov = NccfToPov(tmp(0)), + log_pitch = Log(tmp(1)); + this_stats.sum_pov += accurate_pov; + this_stats.sum_log_pitch_pov += accurate_pov * log_pitch; + } + return; + } + } + // The way we do it here is not the most efficient way to do it; + // we'll see if it becomes a problem. The issue is we have to redo + // this computation from scratch each time we process a new chunk, which + // may be a little inefficient if the chunk-size is very small. 
+ this_stats.cur_num_frames = cur_num_frames; + this_stats.input_finished = input_finished; + this_stats.sum_pov = 0.0; + this_stats.sum_log_pitch_pov = 0.0; + Vector tmp(kRawFeatureDim); + for (int32 f = this_window_begin; f < this_window_end; f++) { + src_->GetFrame(f, &tmp); + BaseFloat accurate_pov = NccfToPov(tmp(0)), + log_pitch = Log(tmp(1)); + this_stats.sum_pov += accurate_pov; + this_stats.sum_log_pitch_pov += accurate_pov * log_pitch; + } +} + +int32 OnlineProcessPitch::NumFramesReady() const { + int32 src_frames_ready = src_->NumFramesReady(); + if (src_frames_ready == 0) { + return 0; + } else if (src_->IsLastFrame(src_frames_ready - 1)) { + return src_frames_ready + opts_.delay; + } else { + return std::max(0, src_frames_ready - + opts_.normalization_right_context + opts_.delay); + } +} + +void ProcessPitch(const ProcessPitchOptions &opts, + const MatrixBase &input, + Matrix *output) { + OnlineMatrixFeature pitch_feat(input); + + OnlineProcessPitch online_process_pitch(opts, &pitch_feat); + + output->Resize(online_process_pitch.NumFramesReady(), + online_process_pitch.Dim()); + for (int32 t = 0; t < online_process_pitch.NumFramesReady(); t++) { + SubVector row(*output, t); + online_process_pitch.GetFrame(t, &row); + } +} + + +void ComputeAndProcessKaldiPitch( + const PitchExtractionOptions &pitch_opts, + const ProcessPitchOptions &process_opts, + const VectorBase &wave, + Matrix *output) { + + OnlinePitchFeature pitch_extractor(pitch_opts); + + if (pitch_opts.simulate_first_pass_online) { + KALDI_ASSERT(pitch_opts.frames_per_chunk > 0 && + "--simulate-first-pass-online option does not make sense " + "unless you specify --frames-per-chunk"); + } + + OnlineProcessPitch post_process(process_opts, &pitch_extractor); + + int32 cur_rows = 100; + Matrix feats(cur_rows, post_process.Dim()); + + int32 cur_offset = 0, cur_frame = 0, + samp_per_chunk = pitch_opts.frames_per_chunk * + pitch_opts.samp_freq * pitch_opts.frame_shift_ms / 1000.0f; + + // We request 
the first-pass features as soon as they are available, + // regardless of whether opts.simulate_first_pass_online == true. If + // opts.simulate_first_pass_online == false this should + // not affect the features generated, but it helps us to test the code + // in a way that's closer to what online decoding would see. + + while (cur_offset < wave.Dim()) { + int32 num_samp; + if (samp_per_chunk > 0) + num_samp = std::min(samp_per_chunk, wave.Dim() - cur_offset); + else // user left opts.frames_per_chunk at zero. + num_samp = wave.Dim(); + SubVector wave_chunk(wave, cur_offset, num_samp); + pitch_extractor.AcceptWaveform(pitch_opts.samp_freq, wave_chunk); + cur_offset += num_samp; + if (cur_offset == wave.Dim()) + pitch_extractor.InputFinished(); + + // Get each frame as soon as it is ready. + for (; cur_frame < post_process.NumFramesReady(); cur_frame++) { + if (cur_frame >= cur_rows) { + cur_rows *= 2; + feats.Resize(cur_rows, post_process.Dim(), kCopyData); + } + SubVector row(feats, cur_frame); + post_process.GetFrame(cur_frame, &row); + } + } + + if (pitch_opts.simulate_first_pass_online) { + if (cur_frame == 0) { + KALDI_WARN << "No features output since wave file too short"; + output->Resize(0, 0); + } else { + *output = feats.RowRange(0, cur_frame); + } + } else { + // want the "final" features for second pass, so get them again.
+ output->Resize(post_process.NumFramesReady(), post_process.Dim()); + for (int32 frame = 0; frame < post_process.NumFramesReady(); frame++) { + SubVector row(*output, frame); + post_process.GetFrame(frame, &row); + } + } +} + + +} // namespace kaldi diff --git a/torchaudio/csrc/kaldi/feat/pitch-functions.h b/torchaudio/csrc/kaldi/feat/pitch-functions.h new file mode 100644 index 00000000000..6419ccaa24b --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/pitch-functions.h @@ -0,0 +1,450 @@ +// feat/pitch-functions.h + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang, +// Daniel Povey, Korbinian Riedhammer +// Xin Lei + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_FEAT_PITCH_FUNCTIONS_H_ +#define KALDI_FEAT_PITCH_FUNCTIONS_H_ + +#include +#include +#include +#include + +#include "base/kaldi-error.h" +#include "feat/mel-computations.h" +#include "itf/online-feature-itf.h" +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +struct PitchExtractionOptions { + // FrameExtractionOptions frame_opts; + BaseFloat samp_freq; // sample frequency in hertz + BaseFloat frame_shift_ms; // in milliseconds. + BaseFloat frame_length_ms; // in milliseconds. 
+ BaseFloat preemph_coeff; // Preemphasis coefficient. [use is deprecated.] + BaseFloat min_f0; // min f0 to search (Hz) + BaseFloat max_f0; // max f0 to search (Hz) + BaseFloat soft_min_f0; // Minimum f0, applied in soft way, must not + // exceed min-f0 + BaseFloat penalty_factor; // cost factor for F0 change + BaseFloat lowpass_cutoff; // cutoff frequency for Low pass filter + BaseFloat resample_freq; // Frequency that we down-sample the signal to; + // must be more than twice lowpass_cutoff + BaseFloat delta_pitch; // the pitch tolerance in pruning lags + BaseFloat nccf_ballast; // Increasing this factor reduces NCCF for + // quiet frames, helping ensure pitch + // continuity in unvoiced region + int32 lowpass_filter_width; // Integer that determines filter width of + // lowpass filter + int32 upsample_filter_width; // Integer that determines filter width when + // upsampling NCCF + + // Below are newer config variables, not present in the original paper, + // that relate to the online pitch extraction algorithm. + + // The maximum number of frames of latency that we allow the pitch-processing + // to introduce, for online operation. If you set this to a large value, + // there would be no inaccuracy from the Viterbi traceback (but it might make + // you wait to see the pitch). This is not very relevant for the online + // operation: normalization-right-context is more relevant, you + // can just leave this value at zero. + int32 max_frames_latency; + + // Only relevant for the function ComputeKaldiPitch which is called by + // compute-kaldi-pitch-feats. If nonzero, we provide the input as chunks of + // this size. This affects the energy normalization which has a small effect + // on the resulting features, especially at the beginning of a file. For best + // compatibility with online operation (e.g. if you plan to train models for + // the online-decoding setup), you might want to set this to a small value, + // like one frame.
+ int32 frames_per_chunk; + + // Only relevant for the function ComputeKaldiPitch which is called by + // compute-kaldi-pitch-feats, and only relevant if frames_per_chunk is + // nonzero. If true, it will query the features as soon as they are + // available, which simulates the first-pass features you would get in online + // decoding. If false, the features you will get will be the same as those + // available at the end of the utterance, after InputFinished() has been + // called: e.g. during lattice rescoring. + bool simulate_first_pass_online; + + // Only relevant for online operation or when emulating online operation + // (e.g. when setting frames_per_chunk). This is the frame-index on which we + // recompute the NCCF (e.g. frame-index 500 = after 5 seconds); if the + // segment ends before this we do it when the segment ends. We do this by + // re-computing the signal average energy, which affects the NCCF via the + // "ballast term", scaling the resampled NCCF by a factor derived from the + // average change in the "ballast term", and re-doing the backtrace + // computation. Making this infinity would be the most exact, but would + // introduce unwanted latency at the end of long utterances, for little + // benefit. + int32 recompute_frame; + + // This is a "hidden config" used only for testing the online pitch + // extraction. If true, we compute the signal root-mean-squared for the + // ballast term, only up to the current frame, rather than the end of the + // current chunk of signal. This makes the output insensitive to the + // chunking, which is useful for testing purposes. 
+ bool nccf_ballast_online; + bool snip_edges; + PitchExtractionOptions(): + samp_freq(16000), + frame_shift_ms(10.0), + frame_length_ms(25.0), + preemph_coeff(0.0), + min_f0(50), + max_f0(400), + soft_min_f0(10.0), + penalty_factor(0.1), + lowpass_cutoff(1000), + resample_freq(4000), + delta_pitch(0.005), + nccf_ballast(7000), + lowpass_filter_width(1), + upsample_filter_width(5), + max_frames_latency(0), + frames_per_chunk(0), + simulate_first_pass_online(false), + recompute_frame(500), + nccf_ballast_online(false), + snip_edges(true) { } + + void Register(OptionsItf *opts) { + opts->Register("sample-frequency", &samp_freq, + "Waveform data sample frequency (must match the waveform " + "file, if specified there)"); + opts->Register("frame-length", &frame_length_ms, "Frame length in " + "milliseconds"); + opts->Register("frame-shift", &frame_shift_ms, "Frame shift in " + "milliseconds"); + opts->Register("preemphasis-coefficient", &preemph_coeff, + "Coefficient for use in signal preemphasis (deprecated)"); + opts->Register("min-f0", &min_f0, + "min. F0 to search for (Hz)"); + opts->Register("max-f0", &max_f0, + "max. F0 to search for (Hz)"); + opts->Register("soft-min-f0", &soft_min_f0, + "Minimum f0, applied in soft way, must not exceed min-f0"); + opts->Register("penalty-factor", &penalty_factor, + "cost factor for FO change."); + opts->Register("lowpass-cutoff", &lowpass_cutoff, + "cutoff frequency for LowPass filter (Hz) "); + opts->Register("resample-frequency", &resample_freq, + "Frequency that we down-sample the signal to. 
Must be " + "more than twice lowpass-cutoff"); + opts->Register("delta-pitch", &delta_pitch, + "Smallest relative change in pitch that our algorithm " + "measures"); + opts->Register("nccf-ballast", &nccf_ballast, + "Increasing this factor reduces NCCF for quiet frames"); + opts->Register("nccf-ballast-online", &nccf_ballast_online, + "This is useful mainly for debug; it affects how the NCCF " + "ballast is computed."); + opts->Register("lowpass-filter-width", &lowpass_filter_width, + "Integer that determines filter width of " + "lowpass filter, more gives sharper filter"); + opts->Register("upsample-filter-width", &upsample_filter_width, + "Integer that determines filter width when upsampling NCCF"); + opts->Register("frames-per-chunk", &frames_per_chunk, "Only relevant for " + "offline pitch extraction (e.g. compute-kaldi-pitch-feats), " + "you can set it to a small nonzero value, such as 10, for " + "better feature compatibility with online decoding (affects " + "energy normalization in the algorithm)"); + opts->Register("simulate-first-pass-online", &simulate_first_pass_online, + "If true, compute-kaldi-pitch-feats will output features " + "that correspond to what an online decoder would see in the " + "first pass of decoding-- not the final version of the " + "features, which is the default. Relevant if " + "--frames-per-chunk > 0"); + opts->Register("recompute-frame", &recompute_frame, "Only relevant for " + "online pitch extraction, or for compatibility with online " + "pitch extraction. A non-critical parameter; the frame at " + "which we recompute some of the forward pointers, after " + "revising our estimate of the signal energy. 
Relevant if " + "--frames-per-chunk > 0"); + opts->Register("max-frames-latency", &max_frames_latency, "Maximum number " + "of frames of latency that we allow pitch tracking to " + "introduce into the feature processing (affects output only " + "if --frames-per-chunk > 0 and " + "--simulate-first-pass-online=true)"); + opts->Register("snip-edges", &snip_edges, "If this is set to false, the " + "incomplete frames near the ending edge won't be snipped, " + "so that the number of frames is the file size divided by " + "the frame-shift. This makes different types of features " + "give the same number of frames."); + } + /// Returns the window-size in samples, after resampling. This is the + /// "basic window size", not the full window size after extending by max-lag. + // Because of floating point representation, it is more reliable to divide + // by 1000 instead of multiplying by 0.001, but it is a bit slower. + int32 NccfWindowSize() const { + return static_cast<int32>(resample_freq * frame_length_ms / 1000.0); + } + /// Returns the window-shift in samples, after resampling. + int32 NccfWindowShift() const { + return static_cast<int32>(resample_freq * frame_shift_ms / 1000.0); + } +}; + +struct ProcessPitchOptions { + BaseFloat pitch_scale; // the final normalized-log-pitch feature is scaled + // with this value + BaseFloat pov_scale; // the final POV feature is scaled with this value + BaseFloat pov_offset; // An offset that can be added to the final POV + // feature (useful for online-decoding, where we don't + // do CMN to the pitch-derived features).
+ + BaseFloat delta_pitch_scale; // the final delta log-pitch feature is scaled with this value + BaseFloat delta_pitch_noise_stddev; // stddev of noise we add to delta-pitch + int32 normalization_left_context; // left-context used for sliding-window + // normalization + int32 normalization_right_context; // this should be reduced in online + // decoding to reduce latency + + int32 delta_window; // number of frames on each side of the central frame used for the delta window + int32 delay; // number of frames by which the pitch information is delayed + + bool add_pov_feature; // if true, the warped NCCF (POV feature) is added to the output + bool add_normalized_log_pitch; // if true, the log-pitch with POV-weighted mean subtraction is added to the output + bool add_delta_pitch; // if true, the time derivative of log-pitch is added to the output + bool add_raw_log_pitch; // if true, log(pitch) is added to the output + + ProcessPitchOptions() : + pitch_scale(2.0), + pov_scale(2.0), + pov_offset(0.0), + delta_pitch_scale(10.0), + delta_pitch_noise_stddev(0.005), + normalization_left_context(75), + normalization_right_context(75), + delta_window(2), + delay(0), + add_pov_feature(true), + add_normalized_log_pitch(true), + add_delta_pitch(true), + add_raw_log_pitch(false) { } + + + void Register(OptionsItf *opts) { + opts->Register("pitch-scale", &pitch_scale, + "Scaling factor for the final normalized log-pitch value"); + opts->Register("pov-scale", &pov_scale, + "Scaling factor for final POV (probability of voicing) " + "feature"); + opts->Register("pov-offset", &pov_offset, + "This can be used to add an offset to the POV feature. " + "Intended for use in online decoding as a substitute for " + " CMN."); + opts->Register("delta-pitch-scale", &delta_pitch_scale, + "Term to scale the final delta log-pitch feature"); + opts->Register("delta-pitch-noise-stddev", &delta_pitch_noise_stddev, + "Standard deviation for noise we add to the delta log-pitch " + "(before scaling); should be about the same as delta-pitch " + "option to pitch creation.
The purpose is to get rid of " + "peaks in the delta-pitch caused by discretization of pitch " + "values."); + opts->Register("normalization-left-context", &normalization_left_context, + "Left-context (in frames) for moving window normalization"); + opts->Register("normalization-right-context", &normalization_right_context, + "Right-context (in frames) for moving window normalization"); + opts->Register("delta-window", &delta_window, + "Number of frames on each side of central frame, to use for " + "delta window."); + opts->Register("delay", &delay, + "Number of frames by which the pitch information is " + "delayed."); + opts->Register("add-pov-feature", &add_pov_feature, + "If true, the warped NCCF is added to output features"); + opts->Register("add-normalized-log-pitch", &add_normalized_log_pitch, + "If true, the log-pitch with POV-weighted mean subtraction " + "over 1.5 second window is added to output features"); + opts->Register("add-delta-pitch", &add_delta_pitch, + "If true, time derivative of log-pitch is added to output " + "features"); + opts->Register("add-raw-log-pitch", &add_raw_log_pitch, + "If true, log(pitch) is added to output features"); + } +}; + + +// We don't want to expose the pitch-extraction internals here as it's +// quite complex, so we use a private implementation. +class OnlinePitchFeatureImpl; + + +// Note: to start on a new waveform, just construct a new instance +// of this object. +class OnlinePitchFeature: public OnlineBaseFeature { + public: + explicit OnlinePitchFeature(const PitchExtractionOptions &opts); + + virtual int32 Dim() const { return 2; /* (NCCF, pitch) */ } + + virtual int32 NumFramesReady() const; + + virtual BaseFloat FrameShiftInSeconds() const; + + virtual bool IsLastFrame(int32 frame) const; + + /// Outputs the two-dimensional feature consisting of (NCCF, pitch). You + /// should probably post-process this using class OnlineProcessPitch.
+ virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat); + + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase<BaseFloat> &waveform); + + virtual void InputFinished(); + + virtual ~OnlinePitchFeature(); + + private: + OnlinePitchFeatureImpl *impl_; +}; + + +/// This online-feature class implements post processing of pitch features. +/// Inputs are original 2 dims (nccf, pitch). It can produce various +/// kinds of outputs, using the default options it will be (pov-feature, +/// normalized-log-pitch, delta-log-pitch). +class OnlineProcessPitch: public OnlineFeatureInterface { + public: + virtual int32 Dim() const { return dim_; } + + virtual bool IsLastFrame(int32 frame) const { + if (frame <= -1) + return src_->IsLastFrame(-1); + else if (frame < opts_.delay) + return src_->IsLastFrame(-1) == true ? false : src_->IsLastFrame(0); + else + return src_->IsLastFrame(frame - opts_.delay); + } + virtual BaseFloat FrameShiftInSeconds() const { + return src_->FrameShiftInSeconds(); + } + + virtual int32 NumFramesReady() const; + + virtual void GetFrame(int32 frame, VectorBase<BaseFloat> *feat); + + virtual ~OnlineProcessPitch() { } + + // Does not take ownership of "src". + OnlineProcessPitch(const ProcessPitchOptions &opts, + OnlineFeatureInterface *src); + + private: + enum { kRawFeatureDim = 2}; // anonymous enum to define a constant. + // kRawFeatureDim defines the dimension + // of the input: (nccf, pitch) + + ProcessPitchOptions opts_; + OnlineFeatureInterface *src_; + int32 dim_; // Output feature dimension, set in initializer. + + struct NormalizationStats { + int32 cur_num_frames; // value of src_->NumFramesReady() when + // "mean_pitch" was set. + bool input_finished; // true if input data was finished when + // "mean_pitch" was computed.
+ double sum_pov; // sum of pov over relevant range + double sum_log_pitch_pov; // sum of log(pitch) * pov over relevant range + + NormalizationStats(): cur_num_frames(-1), input_finished(false), + sum_pov(0.0), sum_log_pitch_pov(0.0) { } + }; + + std::vector<BaseFloat> delta_feature_noise_; + + std::vector<NormalizationStats> normalization_stats_; + + /// Computes and returns the POV feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetPovFeature(int32 frame) const; + + /// Computes and returns the delta-log-pitch feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetDeltaPitchFeature(int32 frame); + + /// Computes and returns the raw log-pitch feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetRawLogPitchFeature(int32 frame) const; + + /// Computes and returns the mean-subtracted log-pitch feature for this frame. + /// Called from GetFrame(). + inline BaseFloat GetNormalizedLogPitchFeature(int32 frame); + + /// Computes the normalization window sizes. + inline void GetNormalizationWindow(int32 frame, + int32 src_frames_ready, + int32 *window_begin, + int32 *window_end) const; + + /// Makes sure the entry in normalization_stats_ for this frame is up to date; + /// called from GetNormalizedLogPitchFeature. + inline void UpdateNormalizationStats(int32 frame); +}; + + +/// This function extracts (NCCF, pitch) per frame, using the pitch extraction +/// method described in "A Pitch Extraction Algorithm Tuned for Automatic Speech +/// Recognition", Pegah Ghahremani, Bagher BabaAli, Daniel Povey, Korbinian +/// Riedhammer, Jan Trmal and Sanjeev Khudanpur, ICASSP 2014. The output will +/// have as many rows as there are frames, and two columns corresponding to +/// (NCCF, pitch) +void ComputeKaldiPitch(const PitchExtractionOptions &opts, + const VectorBase<BaseFloat> &wave, + Matrix<BaseFloat> *output); + +/// This function processes the raw (NCCF, pitch) quantities computed by +/// ComputeKaldiPitch, and processes them into features.
By default it will +/// output three-dimensional features, (POV-feature, mean-subtracted-log-pitch, +/// delta-of-raw-pitch), but this is configurable in the options. The number of +/// rows of "output" will be the number of frames (rows) in "input", and the +/// number of columns will be the number of different types of features +/// requested (by default, 3; 4 is the max). The four config variables +/// --add-pov-feature, --add-normalized-log-pitch, --add-delta-pitch, +/// --add-raw-log-pitch determine which features we create; by default we create +/// the first three. +void ProcessPitch(const ProcessPitchOptions &opts, + const MatrixBase<BaseFloat> &input, + Matrix<BaseFloat> *output); + +/// This function combines ComputeKaldiPitch and ProcessPitch. The reason +/// why we need a separate function to do this is in order to be able to +/// accurately simulate the online pitch-processing, for testing and for +/// training models matched to the "first-pass" features. It is sensitive to +/// the variables in pitch_opts that relate to online processing, +/// i.e. max_frames_latency, frames_per_chunk, simulate_first_pass_online, +/// recompute_frame.
+void ComputeAndProcessKaldiPitch(const PitchExtractionOptions &pitch_opts, + const ProcessPitchOptions &process_opts, + const VectorBase<BaseFloat> &wave, + Matrix<BaseFloat> *output); + + +/// @} End of "addtogroup feat" +} // namespace kaldi +#endif // KALDI_FEAT_PITCH_FUNCTIONS_H_ diff --git a/torchaudio/csrc/kaldi/feat/resample.cc b/torchaudio/csrc/kaldi/feat/resample.cc new file mode 100644 index 00000000000..11f4c62bf1c --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/resample.cc @@ -0,0 +1,377 @@ +// feat/resample.cc + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang +// 2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License.
+ + +#include <algorithm> +#include <limits> +#include "feat/feature-functions.h" +#include "matrix/matrix-functions.h" +#include "feat/resample.h" + +namespace kaldi { + + +LinearResample::LinearResample(int32 samp_rate_in_hz, + int32 samp_rate_out_hz, + BaseFloat filter_cutoff_hz, + int32 num_zeros): + samp_rate_in_(samp_rate_in_hz), + samp_rate_out_(samp_rate_out_hz), + filter_cutoff_(filter_cutoff_hz), + num_zeros_(num_zeros) { + KALDI_ASSERT(samp_rate_in_hz > 0.0 && + samp_rate_out_hz > 0.0 && + filter_cutoff_hz > 0.0 && + filter_cutoff_hz*2 <= samp_rate_in_hz && + filter_cutoff_hz*2 <= samp_rate_out_hz && + num_zeros > 0); + + // base_freq is the frequency of the repeating unit, which is the gcd + // of the input and output frequencies. + int32 base_freq = Gcd(samp_rate_in_, samp_rate_out_); + input_samples_in_unit_ = samp_rate_in_ / base_freq; + output_samples_in_unit_ = samp_rate_out_ / base_freq; + + SetIndexesAndWeights(); + Reset(); +} + +int64 LinearResample::GetNumOutputSamples(int64 input_num_samp, + bool flush) const { + // For exact computation, we measure time in "ticks" of 1.0 / tick_freq, + // where tick_freq is the least common multiple of samp_rate_in_ and + // samp_rate_out_. + int32 tick_freq = Lcm(samp_rate_in_, samp_rate_out_); + int32 ticks_per_input_period = tick_freq / samp_rate_in_; + + // work out the number of ticks in the time interval + // [ 0, input_num_samp/samp_rate_in_ ). + int64 interval_length_in_ticks = input_num_samp * ticks_per_input_period; + if (!flush) { + BaseFloat window_width = num_zeros_ / (2.0 * filter_cutoff_); + // To count the window-width in ticks we take the floor. This + // is because we're looking for the largest integer num-out-samp + // that fits in the interval, which is open on the right, a reduction + // in interval length of less than a tick will never make a difference. + // For example, the largest integer in the interval [ 0, 2 ) and the + // largest integer in the interval [ 0, 2 - 0.9 ) are the same (both one).
+ // So when we're subtracting the window-width we can ignore the fractional + // part. + int32 window_width_ticks = floor(window_width * tick_freq); + // The time-period of the output that we can sample gets reduced + // by the window-width (which is actually the distance from the + // center to the edge of the windowing function) if we're not + // "flushing the output". + interval_length_in_ticks -= window_width_ticks; + } + if (interval_length_in_ticks <= 0) + return 0; + int32 ticks_per_output_period = tick_freq / samp_rate_out_; + // Get the last output-sample in the closed interval, i.e. replacing [ ) with + // [ ]. Note: integer division rounds down. See + // http://en.wikipedia.org/wiki/Interval_(mathematics) for an explanation of + // the notation. + int64 last_output_samp = interval_length_in_ticks / ticks_per_output_period; + // We need the last output-sample in the open interval, so if it takes us to + // the end of the interval exactly, subtract one. + if (last_output_samp * ticks_per_output_period == interval_length_in_ticks) + last_output_samp--; + // First output-sample index is zero, so the number of output samples + // is the last output-sample plus one. + int64 num_output_samp = last_output_samp + 1; + return num_output_samp; +} + +void LinearResample::SetIndexesAndWeights() { + first_index_.resize(output_samples_in_unit_); + weights_.resize(output_samples_in_unit_); + + double window_width = num_zeros_ / (2.0 * filter_cutoff_); + + for (int32 i = 0; i < output_samples_in_unit_; i++) { + double output_t = i / static_cast<double>(samp_rate_out_); + double min_t = output_t - window_width, max_t = output_t + window_width; + // we do ceil on the min and floor on the max, because if we did it + // the other way around we would unnecessarily include indexes just + // outside the window, with zero coefficients. It's possible + // if the arguments to the ceil and floor expressions are integers + // (e.g.
if filter_cutoff_ has an exact ratio with the sample rates), + // that we unnecessarily include something with a zero coefficient, + // but this is only a slight efficiency issue. + int32 min_input_index = ceil(min_t * samp_rate_in_), + max_input_index = floor(max_t * samp_rate_in_), + num_indices = max_input_index - min_input_index + 1; + first_index_[i] = min_input_index; + weights_[i].Resize(num_indices); + for (int32 j = 0; j < num_indices; j++) { + int32 input_index = min_input_index + j; + double input_t = input_index / static_cast<double>(samp_rate_in_), + delta_t = input_t - output_t; + // sign of delta_t doesn't matter. + weights_[i](j) = FilterFunc(delta_t) / samp_rate_in_; + } + } +} + + +// inline +void LinearResample::GetIndexes(int64 samp_out, + int64 *first_samp_in, + int32 *samp_out_wrapped) const { + // A unit is the smallest nonzero amount of time that is an exact + // multiple of the input and output sample periods. The unit index + // is the answer to "which numbered unit we are in". + int64 unit_index = samp_out / output_samples_in_unit_; + // samp_out_wrapped is equal to samp_out % output_samples_in_unit_ + *samp_out_wrapped = static_cast<int32>(samp_out - + unit_index * output_samples_in_unit_); + *first_samp_in = first_index_[*samp_out_wrapped] + + unit_index * input_samples_in_unit_; +} + + +void LinearResample::Resample(const VectorBase<BaseFloat> &input, + bool flush, + Vector<BaseFloat> *output) { + int32 input_dim = input.Dim(); + int64 tot_input_samp = input_sample_offset_ + input_dim, + tot_output_samp = GetNumOutputSamples(tot_input_samp, flush); + + KALDI_ASSERT(tot_output_samp >= output_sample_offset_); + + output->Resize(tot_output_samp - output_sample_offset_); + + // samp_out is the index into the total output signal, not just the part + // of it we are producing here.
+ for (int64 samp_out = output_sample_offset_; + samp_out < tot_output_samp; + samp_out++) { + int64 first_samp_in; + int32 samp_out_wrapped; + GetIndexes(samp_out, &first_samp_in, &samp_out_wrapped); + const Vector<BaseFloat> &weights = weights_[samp_out_wrapped]; + // first_input_index is the first index into "input" that we have a weight + // for. + int32 first_input_index = static_cast<int32>(first_samp_in - + input_sample_offset_); + BaseFloat this_output; + if (first_input_index >= 0 && + first_input_index + weights.Dim() <= input_dim) { + SubVector<BaseFloat> input_part(input, first_input_index, weights.Dim()); + this_output = VecVec(input_part, weights); + } else { // Handle edge cases. + this_output = 0.0; + for (int32 i = 0; i < weights.Dim(); i++) { + BaseFloat weight = weights(i); + int32 input_index = first_input_index + i; + if (input_index < 0 && input_remainder_.Dim() + input_index >= 0) { + this_output += weight * + input_remainder_(input_remainder_.Dim() + input_index); + } else if (input_index >= 0 && input_index < input_dim) { + this_output += weight * input(input_index); + } else if (input_index >= input_dim) { + // We're past the end of the input and are adding zero; should only + // happen if the user specified flush == true, or else we would not + // be trying to output this sample. + KALDI_ASSERT(flush); + } + } + } + int32 output_index = static_cast<int32>(samp_out - output_sample_offset_); + (*output)(output_index) = this_output; + } + + if (flush) { + Reset(); // Reset the internal state. + } else { + SetRemainder(input); + input_sample_offset_ = tot_input_samp; + output_sample_offset_ = tot_output_samp; + } +} + +void LinearResample::SetRemainder(const VectorBase<BaseFloat> &input) { + Vector<BaseFloat> old_remainder(input_remainder_); + // max_remainder_needed is the width of the filter from side to side, + // measured in input samples.
You might think it should be half that, + // but you have to consider that you might be wanting to output samples + // that are "in the past" relative to the beginning of the latest + // input... anyway, storing more remainder than needed is not harmful. + int32 max_remainder_needed = ceil(samp_rate_in_ * num_zeros_ / + filter_cutoff_); + input_remainder_.Resize(max_remainder_needed); + for (int32 index = - input_remainder_.Dim(); index < 0; index++) { + // we interpret "index" as an offset from the end of "input" and + // from the end of input_remainder_. + int32 input_index = index + input.Dim(); + if (input_index >= 0) + input_remainder_(index + input_remainder_.Dim()) = input(input_index); + else if (input_index + old_remainder.Dim() >= 0) + input_remainder_(index + input_remainder_.Dim()) = + old_remainder(input_index + old_remainder.Dim()); + // else leave it at zero. + } +} + +void LinearResample::Reset() { + input_sample_offset_ = 0; + output_sample_offset_ = 0; + input_remainder_.Resize(0); +} + +/** Here, t is a time in seconds representing an offset from + the center of the windowed filter function, and FilterFunc(t) + returns the windowed filter function, described + in the header as h(t) = f(t)g(t), evaluated at t.
+*/ +BaseFloat LinearResample::FilterFunc(BaseFloat t) const { + BaseFloat window, // raised-cosine (Hanning) window, nonzero for + // |t| < num_zeros_ / (2 * filter_cutoff_) + filter; // sinc filter function + if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) + window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); + else + window = 0.0; // outside support of window function + if (t != 0) + filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); + else + filter = 2 * filter_cutoff_; // limit of the function at t = 0 + return filter * window; +} + + +ArbitraryResample::ArbitraryResample( + int32 num_samples_in, BaseFloat samp_rate_in, + BaseFloat filter_cutoff, const Vector<BaseFloat> &sample_points, + int32 num_zeros): + num_samples_in_(num_samples_in), + samp_rate_in_(samp_rate_in), + filter_cutoff_(filter_cutoff), + num_zeros_(num_zeros) { + KALDI_ASSERT(num_samples_in > 0 && samp_rate_in > 0.0 && + filter_cutoff > 0.0 && + filter_cutoff * 2.0 <= samp_rate_in + && num_zeros > 0); + // set up first_index_ and weights_ for the requested sample points. + SetIndexes(sample_points); + SetWeights(sample_points); +} + + +void ArbitraryResample::Resample(const MatrixBase<BaseFloat> &input, + MatrixBase<BaseFloat> *output) const { + // each row of "input" corresponds to the data to resample; + // the corresponding row of "output" is the resampled data.
+ + KALDI_ASSERT(input.NumRows() == output->NumRows() && + input.NumCols() == num_samples_in_ && + output->NumCols() == weights_.size()); + + Vector<BaseFloat> output_col(output->NumRows()); + for (int32 i = 0; i < NumSamplesOut(); i++) { + SubMatrix<BaseFloat> input_part(input, 0, input.NumRows(), + first_index_[i], + weights_[i].Dim()); + const Vector<BaseFloat> &weight_vec(weights_[i]); + output_col.AddMatVec(1.0, input_part, + kNoTrans, weight_vec, 0.0); + output->CopyColFromVec(output_col, i); + } +} + +void ArbitraryResample::Resample(const VectorBase<BaseFloat> &input, + VectorBase<BaseFloat> *output) const { + KALDI_ASSERT(input.Dim() == num_samples_in_ && + output->Dim() == weights_.size()); + + int32 output_dim = output->Dim(); + for (int32 i = 0; i < output_dim; i++) { + SubVector<BaseFloat> input_part(input, first_index_[i], weights_[i].Dim()); + (*output)(i) = VecVec(input_part, weights_[i]); + } +} + +void ArbitraryResample::SetIndexes(const Vector<BaseFloat> &sample_points) { + int32 num_samples = sample_points.Dim(); + first_index_.resize(num_samples); + weights_.resize(num_samples); + BaseFloat filter_width = num_zeros_ / (2.0 * filter_cutoff_); + for (int32 i = 0; i < num_samples; i++) { + // the t values are in seconds. + BaseFloat t = sample_points(i), + t_min = t - filter_width, t_max = t + filter_width; + int32 index_min = ceil(samp_rate_in_ * t_min), + index_max = floor(samp_rate_in_ * t_max); + // the ceil on index min and the floor on index_max are because there + // is no point using indices just outside the window (coeffs would be zero).
+ if (index_min < 0) + index_min = 0; + if (index_max >= num_samples_in_) + index_max = num_samples_in_ - 1; + first_index_[i] = index_min; + weights_[i].Resize(index_max - index_min + 1); + } +} + +void ArbitraryResample::SetWeights(const Vector<BaseFloat> &sample_points) { + int32 num_samples_out = NumSamplesOut(); + for (int32 i = 0; i < num_samples_out; i++) { + for (int32 j = 0 ; j < weights_[i].Dim(); j++) { + BaseFloat delta_t = sample_points(i) - + (first_index_[i] + j) / samp_rate_in_; + // Include at this point the factor of 1.0 / samp_rate_in_ which + // appears in the math. + weights_[i](j) = FilterFunc(delta_t) / samp_rate_in_; + } + } +} + +/** Here, t is a time in seconds representing an offset from + the center of the windowed filter function, and FilterFunc(t) + returns the windowed filter function, described + in the header as h(t) = f(t)g(t), evaluated at t. +*/ +BaseFloat ArbitraryResample::FilterFunc(BaseFloat t) const { + BaseFloat window, // raised-cosine (Hanning) window, nonzero for + // |t| < num_zeros_ / (2 * filter_cutoff_) + filter; // sinc filter function + if (fabs(t) < num_zeros_ / (2.0 * filter_cutoff_)) + window = 0.5 * (1 + cos(M_2PI * filter_cutoff_ / num_zeros_ * t)); + else + window = 0.0; // outside support of window function + if (t != 0.0) + filter = sin(M_2PI * filter_cutoff_ * t) / (M_PI * t); + else + filter = 2.0 * filter_cutoff_; // limit of the function at zero.
+ return filter * window; +} + +void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave) { + BaseFloat min_freq = std::min(orig_freq, new_freq); + BaseFloat lowpass_cutoff = 0.99 * 0.5 * min_freq; + int32 lowpass_filter_width = 6; + LinearResample resampler(orig_freq, new_freq, + lowpass_cutoff, lowpass_filter_width); + resampler.Resample(wave, true, new_wave); +} +} // namespace kaldi diff --git a/torchaudio/csrc/kaldi/feat/resample.h b/torchaudio/csrc/kaldi/feat/resample.h new file mode 100644 index 00000000000..e0b4688c99b --- /dev/null +++ b/torchaudio/csrc/kaldi/feat/resample.h @@ -0,0 +1,287 @@ +// feat/resample.h + +// Copyright 2013 Pegah Ghahremani +// 2014 IMSL, PKU-HKUST (author: Wei Shi) +// 2014 Yanqing Sun, Junjie Wang +// 2014 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + + +#ifndef KALDI_FEAT_RESAMPLE_H_ +#define KALDI_FEAT_RESAMPLE_H_ + +#include +#include +#include +#include + + +#include "matrix/matrix-lib.h" +#include "util/common-utils.h" +#include "base/kaldi-error.h" + +namespace kaldi { +/// @addtogroup feat FeatureExtraction +/// @{ + +/** + \file[resample.h] + + This header contains declarations of classes for resampling signals. 
The + normal cases of resampling a signal are upsampling and downsampling + (increasing and decreasing the sample rate of a signal, respectively), + although the ArbitraryResample class allows a more generic case where + we want to get samples of a signal at uneven intervals (for instance, + log-spaced). + + The input signal is always evenly spaced, say sampled with frequency S, and + we assume the original signal was band-limited to S/2 or lower. The n'th + input sample x_n (with n = 0, 1, ...) is interpreted as the original + signal's value at time n/S. + + For resampling, it is convenient to view the input signal as a + continuous function x(t) of t, where each sample x_n becomes a delta function + with magnitude x_n/S, at time n/S. If we band limit this to the Nyquist + frequency S/2, we can show that this is the same as the original signal + that was sampled. [assuming the original signal was periodic and band + limited.] In general we want to bandlimit to lower than S/2, because + we don't have a perfect filter and also because if we want to resample + at a lower frequency than S, we need to bandlimit to below half of that. + Anyway, suppose we want to bandlimit to C, with 0 < C < S/2. The perfect + rectangular filter with cutoff C is the sinc function, + \f[ f(t) = 2C sinc(2Ct), \f] + where sinc is the normalized sinc function \f$ sinc(t) = sin(pi t) / (pi t) \f$, with + \f$ sinc(0) = 1 \f$. This is not a practical filter, though, because it has + infinite support. At the cost of less-than-perfect rolloff, we can choose + a suitable windowing function g(t), and use f(t) g(t) as the filter. For + a windowing function we choose raised-cosine (Hanning) window with support + on [-w/2C, w/2C], where w >= 2 is an integer chosen by the user. w = 1 + means we window the sinc function out to its first zero on the left and right, + w = 2 means the second zero, and so on; we normally choose w to be at least two. + We call this num_zeros, not w, in the code. 
+ + Convolving the signal x(t) with this windowed filter h(t) = f(t)g(t) and evaluating the resulting + signal s(t) at an arbitrary time t is easy: we have + \f[ s(t) = 1/S \sum_n x_n h(t - n/S) \f]. + (note: the sign of t - n/S might be wrong, but it doesn't matter as the filter + and window are symmetric). + This is true for arbitrary values of t. What the class ArbitraryResample does + is to allow you to evaluate the signal for specified values of t. +*/ + + +/** + Class ArbitraryResample allows you to resample a signal (assumed zero outside + the sample region, not periodic) at arbitrary specified time values, which + don't have to be linearly spaced. The low-pass filter cutoff + "filter_cutoff_hz" should be less than half the sample rate; + "num_zeros" should probably be at least two, preferably more; higher numbers give + sharper filters but will be less efficient. +*/ +class ArbitraryResample { + public: + ArbitraryResample(int32 num_samples_in, + BaseFloat samp_rate_hz, + BaseFloat filter_cutoff_hz, + const Vector<BaseFloat> &sample_points_secs, + int32 num_zeros); + + int32 NumSamplesIn() const { return num_samples_in_; } + + int32 NumSamplesOut() const { return weights_.size(); } + + /// This function does the resampling. + /// input.NumRows() and output.NumRows() should be equal + /// and nonzero. + /// input.NumCols() should equal NumSamplesIn() + /// and output.NumCols() should equal NumSamplesOut(). + void Resample(const MatrixBase<BaseFloat> &input, + MatrixBase<BaseFloat> *output) const; + + /// This version of the Resample function processes just + /// one vector.
+ void Resample(const VectorBase<BaseFloat> &input, + VectorBase<BaseFloat> *output) const; + private: + void SetIndexes(const Vector<BaseFloat> &sample_points); + + void SetWeights(const Vector<BaseFloat> &sample_points); + + BaseFloat FilterFunc(BaseFloat t) const; + + int32 num_samples_in_; + BaseFloat samp_rate_in_; + BaseFloat filter_cutoff_; + int32 num_zeros_; + + std::vector<int32> first_index_; // The first input-sample index that we sum + // over, for this output-sample index. + std::vector<Vector<BaseFloat> > weights_; +}; + + +/** + LinearResample is a special case of ArbitraryResample, where we want to + resample a signal at linearly spaced intervals (this means we want to + upsample or downsample the signal). It is more efficient than + ArbitraryResample because we can construct it just once. + + We require that the input and output sampling rate be specified as + integers, as this is an easy way to specify that their ratio be rational. +*/ + +class LinearResample { + public: + /// Constructor. We make the input and output sample rates integers, because + /// we are going to need to find a common divisor. This should just remind + /// you that they need to be integers. The filter cutoff needs to be less + /// than samp_rate_in_hz/2 and less than samp_rate_out_hz/2. num_zeros + /// controls the sharpness of the filter, more == sharper but less efficient. + /// We suggest around 4 to 10 for normal use. + LinearResample(int32 samp_rate_in_hz, + int32 samp_rate_out_hz, + BaseFloat filter_cutoff_hz, + int32 num_zeros); + + /// This function does the resampling. If you call it with flush == true and + /// you have never called it with flush == false, it just resamples the input + /// signal (it resizes the output to a suitable number of samples). + /// + /// You can also use this function to process a signal a piece at a time. + /// suppose you break it into piece1, piece2, ... pieceN.
You can call + /// \code{.cc} + /// Resample(piece1, &output1, false); + /// Resample(piece2, &output2, false); + /// Resample(piece3, &output3, true); + /// \endcode + /// If you call it with flush == false, it won't output the last few samples + /// but will remember them, so that if you later give it a second piece of + /// the input signal it can process it correctly. + /// If your most recent call to the object was with flush == false, it will + /// have internal state; you can remove this by calling Reset(). + /// Empty input is acceptable. + void Resample(const VectorBase &input, + bool flush, + Vector *output); + + /// Calling the function Reset() resets the state of the object prior to + /// processing a new signal; it is only necessary if you have called + /// Resample(x, y, false) for some signal, leading to a remainder of the + /// signal being stored, but then abandoned processing the signal before calling + /// Resample(x, y, true) for the last piece. Calling it unnecessarily between + /// signals will not do any harm. + void Reset(); + + //// Return the input and output sampling rates (for checks, for example) + inline int32 GetInputSamplingRate() { return samp_rate_in_; } + inline int32 GetOutputSamplingRate() { return samp_rate_out_; } + private: + /// This function outputs the number of output samples we will output + /// for a signal with "input_num_samp" input samples. If flush == true, + /// we return the largest n such that + /// (n/samp_rate_out_) is in the interval [ 0, input_num_samp/samp_rate_in_ ), + /// and note that the interval is half-open. If flush == false, + /// define window_width as num_zeros / (2.0 * filter_cutoff_); + /// we return the largest n such that (n/samp_rate_out_) is in the interval + /// [ 0, input_num_samp/samp_rate_in_ - window_width ).
+ int64 GetNumOutputSamples(int64 input_num_samp, bool flush) const; + + + /// Given an output-sample index, this function outputs to *first_samp_in the + /// first input-sample index that we have a weight on (may be negative), + /// and to *samp_out_wrapped the index into weights_ where we can get the + /// corresponding weights on the input. + inline void GetIndexes(int64 samp_out, + int64 *first_samp_in, + int32 *samp_out_wrapped) const; + + void SetRemainder(const VectorBase &input); + + void SetIndexesAndWeights(); + + BaseFloat FilterFunc(BaseFloat) const; + + // The following variables are provided by the user. + int32 samp_rate_in_; + int32 samp_rate_out_; + BaseFloat filter_cutoff_; + int32 num_zeros_; + + int32 input_samples_in_unit_; ///< The number of input samples in the + ///< smallest repeating unit: num_samp_in_ = + ///< samp_rate_in_hz / Gcd(samp_rate_in_hz, + ///< samp_rate_out_hz) + int32 output_samples_in_unit_; ///< The number of output samples in the + ///< smallest repeating unit: num_samp_out_ = + ///< samp_rate_out_hz / Gcd(samp_rate_in_hz, + ///< samp_rate_out_hz) + + + /// The first input-sample index that we sum over, for this output-sample + /// index. May be negative; any truncation at the beginning is handled + /// separately. This is just for the first few output samples, but we can + /// extrapolate the correct input-sample index for arbitrary output samples. + std::vector first_index_; + + /// Weights on the input samples, for this output-sample index. + std::vector > weights_; + + // the following variables keep track of where we are in a particular signal, + // if it is being provided over multiple calls to Resample(). + + int64 input_sample_offset_; ///< The number of input samples we have + ///< already received for this signal + ///< (including anything in remainder_) + int64 output_sample_offset_; ///< The number of samples we have already + ///< output for this signal. 
+ Vector input_remainder_; ///< A small trailing part of the + ///< previously seen input signal. +}; + +/** + Downsample or upsample a waveform. This is a convenience wrapper for the + class 'LinearResample'. + The low-pass filter cutoff used in 'LinearResample' is 0.99 of the Nyquist, + where the Nyquist is half of the minimum of (orig_freq, new_freq). The + resampling is done with a symmetric FIR filter with N_z (number of zeros) + as 6. + + We compared the downsampling results with those from the sox resampling + toolkit. + Sox's design is inspired by Laurent De Soras' paper, + https://ccrma.stanford.edu/~jos/resample/Implementation.html + + Note: we expect that while orig_freq and new_freq are of type BaseFloat, they + are actually required to have exact integer values (like 16000 or 8000) with + a ratio between them that can be expressed as a rational number with + reasonably small integer factors. +*/ +void ResampleWaveform(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave); + + +/// This function is deprecated. It is provided for backward compatibility, to avoid +/// breaking older code. +inline void DownsampleWaveForm(BaseFloat orig_freq, const VectorBase &wave, + BaseFloat new_freq, Vector *new_wave) { + ResampleWaveform(orig_freq, wave, new_freq, new_wave); +} + + +/// @} End of "addtogroup feat" +} // namespace kaldi +#endif // KALDI_FEAT_RESAMPLE_H_ diff --git a/torchaudio/csrc/kaldi/itf/online-feature-itf.h b/torchaudio/csrc/kaldi/itf/online-feature-itf.h new file mode 100644 index 00000000000..3d139b461f0 --- /dev/null +++ b/torchaudio/csrc/kaldi/itf/online-feature-itf.h @@ -0,0 +1,125 @@ +// itf/online-feature-itf.h + +// Copyright 2013 Johns Hopkins University (author: Daniel Povey) + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_ONLINE_FEATURE_ITF_H_ +#define KALDI_ITF_ONLINE_FEATURE_ITF_H_ 1 +#include "base/kaldi-common.h" +#include "matrix/matrix-lib.h" + +namespace kaldi { +/// @ingroup Interfaces +/// @{ + +/** + OnlineFeatureInterface is an interface for online feature processing (it is + also usable in the offline setting, but currently we're not using it for + that). This is for use in the online2/ directory, and it supersedes the + interface in ../online/online-feat-input.h. We have a slightly different + model that puts more control in the hands of the calling thread, and won't + involve waiting on semaphores in the decoding thread. + + This interface only specifies how the object *outputs* the features. + How it obtains the features, e.g. from a previous object or objects of type + OnlineFeatureInterface, is not specified in the interface and you will + likely define new constructors or methods in the derived type to do that. + + You should appreciate that this interface is designed to allow random + access to features, as long as they are ready. That is, the user + can call GetFrame for any frame less than NumFramesReady(), and when + implementing a child class you must not make assumptions about the + order in which the user makes these calls. +*/ + +class OnlineFeatureInterface { + public: + virtual int32 Dim() const = 0; /// returns the feature dimension. + + /// Returns the total number of frames, since the start of the utterance, that + /// are now available. 
In an online-decoding context, this will likely + /// increase with time as more data becomes available. + virtual int32 NumFramesReady() const = 0; + + /// Returns true if this is the last frame. Frame indices are zero-based, so the + /// first frame is zero. IsLastFrame(-1) will return false, unless the file + /// is empty (which is a case that I'm not sure all the code will handle, so + /// be careful). This function may return false for some frame if + /// we haven't yet decided to terminate decoding, but later true if we decide + /// to terminate decoding. This function exists mainly to correctly handle + /// end effects in feature extraction, and is not a mechanism to determine how + /// many frames are in the decodable object (as it used to be, and for backward + /// compatibility, still is, in the Decodable interface). + virtual bool IsLastFrame(int32 frame) const = 0; + + /// Gets the feature vector for this frame. Before calling this for a given + /// frame, it is assumed that you called NumFramesReady() and it returned a + /// number greater than "frame". Otherwise this call will likely crash with + /// an assert failure. This function is not declared const, in case there is + /// some kind of caching going on, but most of the time it shouldn't modify + /// the class. + virtual void GetFrame(int32 frame, VectorBase *feat) = 0; + + + /// This is like GetFrame() but for a collection of frames. There is a + /// default implementation that just gets the frames one by one, but it + /// may be overridden for efficiency by child classes (since sometimes + /// it's more efficient to do things in a batch). + virtual void GetFrames(const std::vector &frames, + MatrixBase *feats) { + KALDI_ASSERT(static_cast(frames.size()) == feats->NumRows()); + for (size_t i = 0; i < frames.size(); i++) { + SubVector feat(*feats, i); + GetFrame(frames[i], &feat); + } + } + + + // Returns frame shift in seconds. Helps to estimate duration from frame + // counts. 
+ virtual BaseFloat FrameShiftInSeconds() const = 0; + + /// Virtual destructor. Note: constructors that take another member of + /// type OnlineFeatureInterface are not expected to take ownership of + /// that pointer; the caller needs to keep track of that manually. + virtual ~OnlineFeatureInterface() { } + +}; + + +/// Add a virtual class for "source" features such as MFCC or PLP or pitch +/// features. +class OnlineBaseFeature: public OnlineFeatureInterface { + public: + /// This would be called from the application, when you get more wave data. + /// Note: the sampling_rate is typically only provided so the code can assert + /// that it matches the sampling rate expected in the options. + virtual void AcceptWaveform(BaseFloat sampling_rate, + const VectorBase &waveform) = 0; + + /// InputFinished() tells the class you won't be providing any + /// more waveform. This will help flush out the last few frames + /// of delta or LDA features (it will typically affect the return value + /// of IsLastFrame()). + virtual void InputFinished() = 0; +}; + + +/// @} +} // namespace kaldi + +#endif // KALDI_ITF_ONLINE_FEATURE_ITF_H_ diff --git a/torchaudio/csrc/kaldi/itf/options-itf.h b/torchaudio/csrc/kaldi/itf/options-itf.h new file mode 100644 index 00000000000..204f46d6698 --- /dev/null +++ b/torchaudio/csrc/kaldi/itf/options-itf.h @@ -0,0 +1,49 @@ +// itf/options-itf.h + +// Copyright 2013 Tanel Alumae, Tallinn University of Technology + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_ITF_OPTIONS_ITF_H_ +#define KALDI_ITF_OPTIONS_ITF_H_ 1 +#include "base/kaldi-common.h" + +namespace kaldi { + +class OptionsItf { + public: + + virtual void Register(const std::string &name, + bool *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + int32 *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + uint32 *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + float *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + double *ptr, const std::string &doc) = 0; + virtual void Register(const std::string &name, + std::string *ptr, const std::string &doc) = 0; + + virtual ~OptionsItf() {} +}; + +} // namespace Kaldi + +#endif // KALDI_ITF_OPTIONS_ITF_H_ + + diff --git a/torchaudio/csrc/kaldi/kaldi.cc b/torchaudio/csrc/kaldi/kaldi.cc new file mode 100644 index 00000000000..2e586112b2b --- /dev/null +++ b/torchaudio/csrc/kaldi/kaldi.cc @@ -0,0 +1,69 @@ +#include +#include + +namespace { + torch::Tensor denormalize(const torch::Tensor& t) { // returns a scaled copy: positive samples * 32767, negative samples * 32768 + auto ret = t.clone(); // own storage, so the caller's tensor is never modified + auto pos = t > 0, neg = t < 0; + ret.index_put_({pos}, t.index({pos}) * 32767); // in-place; index_put() returns a new tensor that would be discarded + ret.index_put_({neg}, t.index({neg}) * 32768); + return ret; + } +} // namespace + +namespace torchaudio { +namespace kaldi { + + torch::Tensor ComputeKaldiPitch( + const torch::Tensor &wave, + double sample_frequency, + double frame_length, + double frame_shift, + double preemphasis_coefficient, + double
min_f0, + double max_f0, + double soft_min_f0, + double penalty_factor, + double lowpass_cutoff, + double resample_frequency, + double delta_pitch, + double nccf_ballast, + int64_t lowpass_filter_width, + int64_t upsample_filter_width, + int64_t max_frames_latency, + int64_t frames_per_chunk, + bool simulate_first_pass_online, + int64_t recompute_frame, + bool nccf_ballast_online, + bool snip_edges + ) { + // Kaldi's float type expects value range of int16 + ::kaldi::VectorBase<::kaldi::BaseFloat> input(denormalize(wave)); + ::kaldi::PitchExtractionOptions opts; + opts.samp_freq = static_cast<::kaldi::BaseFloat>(sample_frequency); + opts.frame_shift_ms = static_cast<::kaldi::BaseFloat>(frame_shift); + opts.frame_length_ms = static_cast<::kaldi::BaseFloat>(frame_length); + opts.preemph_coeff = static_cast<::kaldi::BaseFloat>(preemphasis_coefficient); + opts.min_f0 = static_cast<::kaldi::BaseFloat>(min_f0); + opts.max_f0 = static_cast<::kaldi::BaseFloat>(max_f0); + opts.soft_min_f0 = static_cast<::kaldi::BaseFloat>(soft_min_f0); + opts.penalty_factor = static_cast<::kaldi::BaseFloat>(penalty_factor); + opts.lowpass_cutoff = static_cast<::kaldi::BaseFloat>(lowpass_cutoff); + opts.resample_freq = static_cast<::kaldi::BaseFloat>(resample_frequency); + opts.delta_pitch = static_cast<::kaldi::BaseFloat>(delta_pitch); + opts.nccf_ballast = static_cast<::kaldi::BaseFloat>(nccf_ballast); + opts.lowpass_filter_width = static_cast<::kaldi::int32>(lowpass_filter_width); + opts.upsample_filter_width = static_cast<::kaldi::int32>(upsample_filter_width); + opts.max_frames_latency = static_cast<::kaldi::int32>(max_frames_latency); + opts.frames_per_chunk = static_cast<::kaldi::int32>(frames_per_chunk); + opts.simulate_first_pass_online = simulate_first_pass_online; + opts.recompute_frame = static_cast<::kaldi::int32>(recompute_frame); + opts.nccf_ballast_online = nccf_ballast_online; + opts.snip_edges = snip_edges; + ::kaldi::Matrix<::kaldi::BaseFloat> output; + 
::kaldi::ComputeKaldiPitch(opts, input, &output); + return output.tensor_; + } + +} // namespace kaldi +} // namespace torchaudio diff --git a/torchaudio/csrc/kaldi/kaldi.h b/torchaudio/csrc/kaldi/kaldi.h new file mode 100644 index 00000000000..98787082d81 --- /dev/null +++ b/torchaudio/csrc/kaldi/kaldi.h @@ -0,0 +1,35 @@ +#ifndef TORCHAUDIO_CSRC_KALDI_WRAPPER_H +#define TORCHAUDIO_CSRC_KALDI_WRAPPER_H + +#include + +namespace torchaudio { +namespace kaldi { + + torch::Tensor ComputeKaldiPitch( + const torch::Tensor &wave, + double sample_frequency, + double frame_length, + double frame_shift, + double preemphasis_coefficient, + double min_f0, + double max_f0, + double soft_min_f0, + double penalty_factor, + double lowpass_cutoff, + double resample_frequency, + double delta_pitch, + double nccf_ballast, + int64_t lowpass_filter_width, + int64_t upsample_filter_width, + int64_t max_frames_latency, + int64_t frames_per_chunk, + bool simulate_first_pass_online, + int64_t recompute_frame, + bool nccf_ballast_online, + bool snip_edges); + +} // namespace kaldi +} // namespace torchaudio + +#endif diff --git a/torchaudio/csrc/kaldi/matrix/compressed-matrix.h b/torchaudio/csrc/kaldi/matrix/compressed-matrix.h new file mode 100644 index 00000000000..78105b9b58a --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/compressed-matrix.h @@ -0,0 +1,283 @@ +// matrix/compressed-matrix.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// Frantisek Skala, Wei Shi + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_MATRIX_COMPRESSED_MATRIX_H_ +#define KALDI_MATRIX_COMPRESSED_MATRIX_H_ 1 + +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// \addtogroup matrix_group +/// @{ + + + +/* + The enum CompressionMethod is used when creating a CompressedMatrix (a lossily + compressed matrix) from a regular Matrix. It dictates how we choose the + compressed format and how we choose the ranges of floats that are represented + by particular integers. + + kAutomaticMethod = 1 This is the default when you don't specify the + compression method. It is a shorthand for using + kSpeechFeature if the num-rows is more than 8, and + kTwoByteAuto otherwise. + kSpeechFeature = 2 This is the most complicated of the compression methods, + and was designed for speech features which have a roughly + Gaussian distribution with different ranges for each + dimension. Each element is stored in one byte, but there + is an 8-byte header per column; the spacing of the + integer values is not uniform but is in 3 ranges. + kTwoByteAuto = 3 Each element is stored in two bytes as a uint16, with + the representable range of values chosen automatically + with the minimum and maximum elements of the matrix as + its edges. + kTwoByteSignedInteger = 4 + Each element is stored in two bytes as a uint16, with + the representable range of value chosen to coincide with + what you'd get if you stored signed integers, i.e. + [-32768.0, 32767.0]. Suitable for waveform data that + was previously stored as 16-bit PCM. 
+ kOneByteAuto = 5 Each element is stored in one byte as a uint8, with the + representable range of values chosen automatically with + the minimum and maximum elements of the matrix as its + edges. + kOneByteUnsignedInteger = 6 Each element is stored in + one byte as a uint8, with the representable range of + values equal to [0.0, 255.0]. + kOneByteZeroOne = 7 Each element is stored in + one byte as a uint8, with the representable range of + values equal to [0.0, 1.0]. Suitable for image data + that has previously been compressed as int8. + + // We can add new methods here as needed: if they just imply different ways + // of selecting the min_value and range, and a num-bytes = 1 or 2, they will + // be trivial to implement. +*/ +enum CompressionMethod { + kAutomaticMethod = 1, + kSpeechFeature = 2, + kTwoByteAuto = 3, + kTwoByteSignedInteger = 4, + kOneByteAuto = 5, + kOneByteUnsignedInteger = 6, + kOneByteZeroOne = 7 +}; + + +/* + This class does lossy compression of a matrix. It supports various compression + methods, see enum CompressionMethod. +*/ + +class CompressedMatrix { + public: + CompressedMatrix(): data_(NULL) { } + + ~CompressedMatrix() { Clear(); } + + template + explicit CompressedMatrix(const MatrixBase &mat, + CompressionMethod method = kAutomaticMethod): + data_(NULL) { CopyFromMat(mat, method); } + + /// Initializer that can be used to select part of an existing + /// CompressedMatrix without un-compressing and re-compressing (note: unlike + /// similar initializers for class Matrix, it doesn't point to the same memory + /// location). + /// + /// This creates a CompressedMatrix with the size (num_rows, num_cols) + /// starting at (row_offset, col_offset). + /// + /// If you specify allow_padding = true, + /// it is permitted to have row_offset < 0 and + /// row_offset + num_rows > mat.NumRows(), and the result will contain + /// repeats of the first and last rows of 'mat' as necessary. 
+ CompressedMatrix(const CompressedMatrix &mat, + const MatrixIndexT row_offset, + const MatrixIndexT num_rows, + const MatrixIndexT col_offset, + const MatrixIndexT num_cols, + bool allow_padding = false); + + void *Data() const { return this->data_; } + + /// This will resize *this and copy the contents of mat to *this. + template + void CopyFromMat(const MatrixBase &mat, + CompressionMethod method = kAutomaticMethod); + + CompressedMatrix(const CompressedMatrix &mat); + + CompressedMatrix &operator = (const CompressedMatrix &mat); // assignment operator. + + template + CompressedMatrix &operator = (const MatrixBase &mat); // assignment operator. + + /// Copies contents to matrix. Note: mat must have the correct size. + /// The kTrans case uses a temporary. + template + void CopyToMat(MatrixBase *mat, + MatrixTransposeType trans = kNoTrans) const; + + void Write(std::ostream &os, bool binary) const; + + void Read(std::istream &is, bool binary); + + /// Returns number of rows (or zero for empty matrix). + inline MatrixIndexT NumRows() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast(data_)).num_rows; } + + /// Returns number of columns (or zero for empty matrix). + inline MatrixIndexT NumCols() const { return (data_ == NULL) ? 0 : + (*reinterpret_cast(data_)).num_cols; } + + /// Copies row #row of the matrix into vector v. + /// Note: v must have same size as #cols. + template + void CopyRowToVec(MatrixIndexT row, VectorBase *v) const; + + /// Copies column #col of the matrix into vector v. + /// Note: v must have same size as #rows. + template + void CopyColToVec(MatrixIndexT col, VectorBase *v) const; + + /// Copies submatrix of compressed matrix into matrix dest.
+ /// Submatrix starts at row row_offset and column column_offset and its size + /// is defined by size of provided matrix dest + template + void CopyToMat(int32 row_offset, + int32 column_offset, + MatrixBase *dest) const; + + void Swap(CompressedMatrix *other) { std::swap(data_, other->data_); } + + void Clear(); + + /// scales all elements of matrix by alpha. + /// It scales the floating point values in GlobalHeader by alpha. + void Scale(float alpha); + + friend class Matrix; + friend class Matrix; + private: + + // This enum describes the different compressed-data formats: these are + // distinct from the compression methods although all of the methods apart + // from kAutomaticMethod dictate a particular compressed-data format. + // + // kOneByteWithColHeaders means there is a GlobalHeader and each + // column has a PerColHeader; the actual data is stored in + // one byte per element, in column-major order (the mapping + // from integers to floats is a little complicated). + // kTwoByte means there is a global header but no PerColHeader; + // the actual data is stored in two bytes per element in + // row-major order; it's decompressed as: + // uint16 i; GlobalHeader g; + // float f = g.min_value + i * (g.range / 65535.0) + // kOneByte means there is a global header but not PerColHeader; + // the data is stored in one byte per element in row-major + // order and is decompressed as: + // uint8 i; GlobalHeader g; + // float f = g.min_value + i * (g.range / 255.0) + enum DataFormat { + kOneByteWithColHeaders = 1, + kTwoByte = 2, + kOneByte = 3 + }; + + + // allocates data using new [], ensures byte alignment + // sufficient for float. + static void *AllocateData(int32 num_bytes); + + struct GlobalHeader { + int32 format; // Represents the enum DataFormat. 
+ float min_value; // min_value and range represent the ranges of the integer + // data in the kTwoByte and kOneByte formats, and the + // range of the PerColHeader uint16's in the + // kOneByteWithColheaders format. + float range; + int32 num_rows; + int32 num_cols; + }; + + // This function computes the global header for compressing this data. + template + static inline void ComputeGlobalHeader(const MatrixBase &mat, + CompressionMethod method, + GlobalHeader *header); + + + // The number of bytes we need to request when allocating 'data_'. + static MatrixIndexT DataSize(const GlobalHeader &header); + + // This struct is only used in format kOneByteWithColHeaders. + struct PerColHeader { + uint16 percentile_0; + uint16 percentile_25; + uint16 percentile_75; + uint16 percentile_100; + }; + + template + static void CompressColumn(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header, + uint8 *byte_data); + template + static void ComputeColHeader(const GlobalHeader &global_header, + const Real *data, MatrixIndexT stride, + int32 num_rows, PerColHeader *header); + + static inline uint16 FloatToUint16(const GlobalHeader &global_header, + float value); + + // this is used only in the kOneByte compression format. + static inline uint8 FloatToUint8(const GlobalHeader &global_header, + float value); + + static inline float Uint16ToFloat(const GlobalHeader &global_header, + uint16 value); + + // this is used only in the kOneByteWithColHeaders compression format. + static inline uint8 FloatToChar(float p0, float p25, + float p75, float p100, + float value); + + // this is used only in the kOneByteWithColHeaders compression format. + static inline float CharToFloat(float p0, float p25, + float p75, float p100, + uint8 value); + + void *data_; // first GlobalHeader, then PerColHeader (repeated), then + // the byte data for each column (repeated). 
Note: don't intersperse + // the byte data with the PerColHeaders, because of alignment issues. + +}; + +/// @} end of \addtogroup matrix_group + + +} // namespace kaldi + + +#endif // KALDI_MATRIX_COMPRESSED_MATRIX_H_ diff --git a/torchaudio/csrc/kaldi/matrix/kaldi-matrix.cc b/torchaudio/csrc/kaldi/matrix/kaldi-matrix.cc new file mode 100644 index 00000000000..a703c66c13c --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/kaldi-matrix.cc @@ -0,0 +1,36 @@ +#include "matrix/kaldi-matrix.h" + +namespace { + +template +void assert_matrix_shape(const torch::Tensor &tensor_); + +template<> +void assert_matrix_shape(const torch::Tensor &tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32); +} + +template<> +void assert_matrix_shape(const torch::Tensor &tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 2); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64); +} + +} // namespace + +namespace kaldi { + +template +MatrixBase::MatrixBase(torch::Tensor tensor) : tensor_(tensor) { + assert_matrix_shape(tensor_); +}; + +template class Matrix; +template class Matrix; +template class MatrixBase; +template class MatrixBase; +template class SubMatrix; +template class SubMatrix; + +} // namespace kaldi diff --git a/torchaudio/csrc/kaldi/matrix/kaldi-matrix.h b/torchaudio/csrc/kaldi/matrix/kaldi-matrix.h new file mode 100644 index 00000000000..b534040cbe1 --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/kaldi-matrix.h @@ -0,0 +1,163 @@ +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h + +#ifndef KALDI_MATRIX_KALDI_MATRIX_H_ +#define KALDI_MATRIX_KALDI_MATRIX_H_ + +#include +#include "matrix/matrix-common.h" +#include "matrix/kaldi-vector.h" + +using namespace torch::indexing; + +namespace kaldi { + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L44-L48 +template +class 
MatrixBase { +public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-specific items + //////////////////////////////////////////////////////////////////////////////// + torch::Tensor tensor_; + /// Construct MatrixBase which is an interface to an existing torch::Tensor object. + MatrixBase(torch::Tensor tensor); + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible items + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L62-L63 + inline MatrixIndexT NumRows() const { return tensor_.size(0); }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L65-L66 + inline MatrixIndexT NumCols() const { return tensor_.size(1); }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L177-L178 + void CopyColFromVec(const VectorBase &v, const MatrixIndexT col) { + tensor_.index_put_({Slice(), col}, v.tensor_); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L99-L107 + inline Real& operator() (MatrixIndexT r, MatrixIndexT c) { + // CPU only + return tensor_.accessor()[r][c]; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L112-L120 + inline const Real operator() (MatrixIndexT r, MatrixIndexT c) const { // read-only access to element (r, c), returned by value + return tensor_.index({r, c}).item().template to(); // integer indices select one element; Slice(r), Slice(c) would select the block [r:, c:] and item() would throw + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L138-L141 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L859-L898 + template + void CopyFromMat(const MatrixBase & M, +
MatrixTransposeType trans = kNoTrans) { + auto src = M.tensor_; + if (trans == kTrans) + src = src.transpose(1, 0); + tensor_.index_put_({Slice(), Slice()}, src); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L186-L191 + inline const SubVector Row(MatrixIndexT i) const { + return SubVector(*this, i); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L208-L211 + inline SubMatrix RowRange(const MatrixIndexT row_offset, + const MatrixIndexT num_rows) const { + return SubMatrix(*this, row_offset, num_rows, 0, NumCols()); + } + +protected: + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L749-L753 + explicit MatrixBase(): tensor_(torch::empty({0, 0})) { + KALDI_ASSERT_IS_FLOATING_TYPE(Real); + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L781-L784 +template +class Matrix : public MatrixBase { +public: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L786-L787 + Matrix() : MatrixBase() {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L789-L793 + Matrix(const MatrixIndexT r, const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) + : MatrixBase() { Resize(r, c, resize_type, stride_type); } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L808-L811 + explicit Matrix(const MatrixBase & M, + MatrixTransposeType trans = kNoTrans) + : MatrixBase(trans == kNoTrans ? 
M.tensor_ : M.tensor_.transpose(1, 0)) + {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L816-L819 + template + explicit Matrix(const MatrixBase & M, + MatrixTransposeType trans = kNoTrans) + : MatrixBase(trans == kNoTrans ? M.tensor_ : M.tensor_.transpose(1, 0)) + {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L859-L874 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.cc#L817-L857 + void Resize(const MatrixIndexT r, + const MatrixIndexT c, + MatrixResizeType resize_type = kSetZero, + MatrixStrideType stride_type = kDefaultStride) { + auto &tensor_ = MatrixBase::tensor_; + switch(resize_type) { + case kSetZero: + tensor_.resize_({r, c}).zero_(); + break; + case kUndefined: + tensor_.resize_({r, c}); + break; + case kCopyData: + auto tmp = tensor_.clone(); // deep copy: a plain handle copy aliases tensor_ and would be resized and zeroed together with it + auto tmp_rows = tmp.size(0); + auto tmp_cols = tmp.size(1); + tensor_.resize_({r, c}).zero_(); + auto rows = Slice(None, r < tmp_rows ? r : tmp_rows); + auto cols = Slice(None, c < tmp_cols ? 
c : tmp_cols); + tensor_.index_put_({rows, cols}, tmp.index({rows, cols})); + break; + } + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L876-L883 + Matrix &operator = (const MatrixBase &other) { + if (MatrixBase::NumRows() != other.NumRows() || + MatrixBase::NumCols() != other.NumCols()) + Resize(other.NumRows(), other.NumCols(), kUndefined); + MatrixBase::CopyFromMat(other); + return *this; + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L940-L948 +template +class SubMatrix : public MatrixBase { +public: + SubMatrix(const MatrixBase& T, + const MatrixIndexT ro, // row offset, 0 < ro < NumRows() + const MatrixIndexT r, // number of rows, r > 0 + const MatrixIndexT co, // column offset, 0 < co < NumCols() + const MatrixIndexT c) // number of columns, c > 0 + : MatrixBase(T.tensor_.index({Slice(ro, ro+r), Slice(co, co+c)})) {} +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-matrix.h#L1059-L1060 +template +std::ostream & operator << (std::ostream & Out, const MatrixBase & M) { + Out << M.tensor_; + return Out; +} + +} // namespace kaldi + +#endif diff --git a/torchaudio/csrc/kaldi/matrix/kaldi-vector.cc b/torchaudio/csrc/kaldi/matrix/kaldi-vector.cc new file mode 100644 index 00000000000..114a74666cd --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/kaldi-vector.cc @@ -0,0 +1,43 @@ +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace { + +template +void assert_vector_shape(const torch::Tensor &tensor_); + +template<> +void assert_vector_shape(const torch::Tensor &tensor_) { + TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat32); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +template<> +void assert_vector_shape(const torch::Tensor &tensor_) { + 
TORCH_INTERNAL_ASSERT(tensor_.ndimension() == 1); + TORCH_INTERNAL_ASSERT(tensor_.dtype() == torch::kFloat64); + TORCH_CHECK(tensor_.device().is_cpu(), "Input tensor has to be on CPU."); +} + +} // namespace + +namespace kaldi { + +template +VectorBase::VectorBase(torch::Tensor tensor) : + tensor_(tensor), + data_(tensor.data_ptr()) +{ + assert_vector_shape(tensor_); +}; + +template +VectorBase::VectorBase() : VectorBase(torch::empty({0})) {} + +template class Vector; +template class Vector; +template class VectorBase; +template class VectorBase; + +} // namespace kaldi diff --git a/torchaudio/csrc/kaldi/matrix/kaldi-vector.h b/torchaudio/csrc/kaldi/matrix/kaldi-vector.h new file mode 100644 index 00000000000..59ab6349831 --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/kaldi-vector.h @@ -0,0 +1,281 @@ +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h + +#ifndef KALDI_MATRIX_KALDI_VECTOR_H_ +#define KALDI_MATRIX_KALDI_VECTOR_H_ + +#include +#include "matrix/matrix-common.h" + +using namespace torch::indexing; + +namespace kaldi { + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L36-L40 +template +class VectorBase { +public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-specific things + //////////////////////////////////////////////////////////////////////////////// + torch::Tensor tensor_; + + /// Construct VectorBase which is an interface to an existing torch::Tensor object. 
+ VectorBase(torch::Tensor tensor); + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible methods + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L42-L43 + void SetZero() { Set(0); } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L48-L49 + void Set(Real f) { tensor_.index_put_({"..."}, f); } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L62-L63 + inline MatrixIndexT Dim() const { return tensor_.numel(); }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L68-L69 + inline Real* Data() { return data_; } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L71-L72 + inline const Real* Data() const { return data_; } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L74-L79 + inline Real operator() (MatrixIndexT i) const { + return data_[i]; + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L81-L86 + inline Real& operator() (MatrixIndexT i) { + return tensor_.accessor()[i]; + }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L88-L95 + SubVector Range(const MatrixIndexT o, const MatrixIndexT l) { + return SubVector(*this, o, l); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L97-L105 + const SubVector Range(const MatrixIndexT o, + const MatrixIndexT l) const { + return SubVector(*this, o, l); + } + + // 
https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L107-L108 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L226-L233 + void CopyFromVec(const VectorBase &v) { + TORCH_INTERNAL_ASSERT(tensor_.sizes() == v.tensor_.sizes()); + tensor_.copy_(v.tensor_); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L137-L139 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L816-L832 + void ApplyFloor(Real floor_val, MatrixIndexT *floored_count = nullptr) { + auto index = tensor_ < floor_val; + auto tmp = tensor_.index_put_({index}, floor_val); + if (floored_count) { + *floored_count = index.sum().item().template to(); + } + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L164-L165 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L449-L479 + void ApplyPow(Real power) { + tensor_.pow_(power); + TORCH_INTERNAL_ASSERT(! 
tensor_.isnan().sum().item().template to()); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L181-L184 + template + void AddVec(const Real alpha, const VectorBase &v) { + tensor_ += alpha * v.tensor_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L186-L187 + void AddVec2(const Real alpha, const VectorBase &v) { + tensor_ += alpha * (v.tensor_.square()); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L196-L198 + void AddMatVec(const Real alpha, const MatrixBase &M, + const MatrixTransposeType trans, const VectorBase &v, + const Real beta) { // **beta previously defaulted to 0.0** + auto mat = M.tensor_; + if (trans == kTrans) { + mat = mat.transpose(1, 0); + } + tensor_.addmv_(mat, v.tensor_, beta, alpha); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L221-L222 + void MulElements(const VectorBase &v) { + tensor_ *= v.tensor_; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L233-L234 + void Add(Real c) { tensor_ += c; } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L236-L239 + void AddVecVec(Real alpha, const VectorBase &v, + const VectorBase &r, Real beta) { + tensor_.copy_(beta * tensor_ + alpha * v.tensor_ * r.tensor_); // write into the existing storage: rebinding tensor_ would leave data_ dangling and detach SubVector views from their parent; the right-hand side is computed first, so v or r may alias *this + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L246-L247 + void Scale(Real alpha) { tensor_ *= alpha; } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L305-L306 + Real Min() const { + if (tensor_.numel()) { + return tensor_.min().item().template to(); + } + return std::numeric_limits::infinity(); + } + 
+ // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L308-L310 + Real Min(MatrixIndexT *index) const { + TORCH_INTERNAL_ASSERT(tensor_.numel()); + torch::Tensor value, ind; + std::tie(value, ind) = tensor_.min(0); + *index = ind.item().to(); + return value.item().to(); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L312-L313 + Real Sum() const { return tensor_.sum().item().template to(); }; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L320-L321 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L718-L736 + void AddRowSumMat(Real alpha, const MatrixBase &M, Real beta = 1.0) { + Vector ones(M.NumRows()); + ones.Set(1.0); + this->AddMatVec(alpha, M, kTrans, ones, beta); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L323-L324 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L738-L757 + void AddColSumMat(Real alpha, const MatrixBase &M, Real beta = 1.0) { + Vector ones(M.NumCols()); + ones.Set(1.0); + this->AddMatVec(alpha, M, kNoTrans, ones, beta); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L326-L330 + void AddDiagMat2(Real alpha, const MatrixBase &M, + MatrixTransposeType trans = kNoTrans, Real beta = 1.0) { + auto mat = M.tensor_; // *this = beta * *this + alpha * diag(M M^T) (or diag(M^T M) when trans == kTrans) + if (trans == kNoTrans) { + tensor_.copy_(beta * tensor_ + alpha * torch::diag(torch::mm(mat, mat.transpose(1, 0)))); // alpha was previously dropped; copy_ keeps data_ and any SubVector view valid + } else { + tensor_.copy_(beta * tensor_ + alpha * torch::diag(torch::mm(mat.transpose(1, 0), mat))); + } + } + +protected: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L362-L365 + explicit VectorBase(); + + // 
https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L378-L379 + Real* data_; + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L382 + KALDI_DISALLOW_COPY_AND_ASSIGN(VectorBase); +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L385-L390 +template +class Vector : public VectorBase { +public: + //////////////////////////////////////////////////////////////////////////////// + // PyTorch-compatibility things + //////////////////////////////////////////////////////////////////////////////// + /// Construct VectorBase which is an interface to an existing torch::Tensor object. + Vector(torch::Tensor tensor) : VectorBase(tensor) {}; + + //////////////////////////////////////////////////////////////////////////////// + // Kaldi-compatible methods + //////////////////////////////////////////////////////////////////////////////// + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L392-L393 + Vector(): VectorBase() {}; + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L395-L399 + explicit Vector(const MatrixIndexT s, + MatrixResizeType resize_type = kSetZero) + : VectorBase() { Resize(s, resize_type); } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L406-L410 + // Note: unlike the original implementation, this is "explicit". 
+ explicit Vector(const Vector &v) : VectorBase(v.tensor_.clone()) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L412-L416 + explicit Vector(const VectorBase &v) : VectorBase(v.tensor_.clone()) {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L434-L435 + void Swap(Vector *other) { // data_ caches tensor_.data_ptr(), so it must be exchanged together with tensor_ + auto tmp = VectorBase::tensor_; auto tmp_data = this->data_; + this->tensor_ = other->tensor_; this->data_ = other->data_; + other->tensor_ = tmp; other->data_ = tmp_data; + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L444-L451 + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.cc#L189-L223 + void Resize(MatrixIndexT length, MatrixResizeType resize_type = kSetZero) { + auto &tensor_ = this->tensor_; + switch(resize_type) { + case kSetZero: + tensor_.resize_({length}).zero_(); + break; + case kUndefined: + tensor_.resize_({length}); + break; + case kCopyData: + auto tmp = tensor_.clone(); // deep copy: a plain handle copy aliases tensor_ and would be resized and zeroed together with it + auto tmp_numel = tensor_.numel(); + tensor_.resize_({length}).zero_(); + auto numel = Slice(None, length < tmp_numel ? 
length : tmp_numel); + tensor_.index_put_({numel}, tmp.index({numel})); + break; + } + // data_ptr() causes compiler error + this->data_ = static_cast(tensor_.data_ptr()); + } + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L463-L468 + Vector &operator = (const VectorBase &other) { + Resize(other.Dim(), kUndefined); + this->CopyFromVec(other); + return *this; + } +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L482-L485 +template +class SubVector : public VectorBase { +public: + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L487-L499 + SubVector(const VectorBase &t, const MatrixIndexT origin, + const MatrixIndexT length) + : VectorBase(t.tensor_.index({Slice(origin, origin + length)})) + {} + + // https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L524-L528 + SubVector(const MatrixBase &matrix, MatrixIndexT row) + : VectorBase(matrix.tensor_.index({row})) + {} +}; + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L540-L543 +template +std::ostream & operator << (std::ostream & out, const VectorBase & v) { + out << v.tensor_; + return out; +} + +// https://github.com/kaldi-asr/kaldi/blob/7fb716aa0f56480af31514c7e362db5c9f787fd4/src/matrix/kaldi-vector.h#L573-L575 +template +Real VecVec(const VectorBase &v1, const VectorBase &v2) { + return torch::dot(v1.tensor_, v2.tensor_).item().template to(); +} + +} // namespace kaldi + +#endif diff --git a/torchaudio/csrc/kaldi/matrix/matrix-common.h b/torchaudio/csrc/kaldi/matrix/matrix-common.h new file mode 100644 index 00000000000..f7047d71ca5 --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/matrix-common.h @@ -0,0 +1,111 @@ +// matrix/matrix-common.h + +// Copyright 2009-2011 Microsoft Corporation + +// See 
../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +#ifndef KALDI_MATRIX_MATRIX_COMMON_H_ +#define KALDI_MATRIX_MATRIX_COMMON_H_ + +// This file contains some #includes, forward declarations +// and typedefs that are needed by all the main header +// files in this directory. + +#include "base/kaldi-common.h" + +namespace kaldi { +// this enums equal to CblasTrans and CblasNoTrans constants from CBLAS library +// we are writing them as literals because we don't want to include here matrix/kaldi-blas.h, +// which puts many symbols into global scope (like "real") via the header f2c.h +typedef enum { + kTrans = 112, // = CblasTrans + kNoTrans = 111 // = CblasNoTrans +} MatrixTransposeType; + +typedef enum { + kSetZero, + kUndefined, + kCopyData +} MatrixResizeType; + + +typedef enum { + kDefaultStride, + kStrideEqualNumCols, +} MatrixStrideType; + +typedef enum { + kTakeLower, + kTakeUpper, + kTakeMean, + kTakeMeanAndCheck +} SpCopyType; + +template class VectorBase; +template class Vector; +template class SubVector; +template class MatrixBase; +template class SubMatrix; +template class Matrix; +template class SpMatrix; +template class TpMatrix; +template class PackedMatrix; +template class SparseMatrix; + +// these are classes that won't be defined in this +// directory; they're mostly needed for friend declarations. 
+template class CuMatrixBase; +template class CuSubMatrix; +template class CuMatrix; +template class CuVectorBase; +template class CuSubVector; +template class CuVector; +template class CuPackedMatrix; +template class CuSpMatrix; +template class CuTpMatrix; +template class CuSparseMatrix; + +class CompressedMatrix; +class GeneralMatrix; + +/// This class provides a way for switching between double and float types. +template class OtherReal { }; // useful in reading+writing routines + // to switch double and float. +/// A specialized class for switching from float to double. +template<> class OtherReal { + public: + typedef double Real; +}; +/// A specialized class for switching from double to float. +template<> class OtherReal { + public: + typedef float Real; +}; + + +typedef int32 MatrixIndexT; +typedef int32 SignedMatrixIndexT; +typedef uint32 UnsignedMatrixIndexT; + +// If you want to use size_t for the index type, do as follows instead: +//typedef size_t MatrixIndexT; +//typedef ssize_t SignedMatrixIndexT; +//typedef size_t UnsignedMatrixIndexT; + +} + + + +#endif // KALDI_MATRIX_MATRIX_COMMON_H_ diff --git a/torchaudio/csrc/kaldi/matrix/matrix-functions-inl.h b/torchaudio/csrc/kaldi/matrix/matrix-functions-inl.h new file mode 100644 index 00000000000..9fac851efd3 --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/matrix-functions-inl.h @@ -0,0 +1,56 @@ +// matrix/matrix-functions-inl.h + +// Copyright 2009-2011 Microsoft Corporation +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + +namespace kaldi { + +//! ComplexMul implements, inline, the complex multiplication b *= a. +template inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im) { + Real tmp_re = (*b_re * a_re) - (*b_im * a_im); + *b_im = *b_re * a_im + *b_im * a_re; + *b_re = tmp_re; +} + +template inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im) { + *c_re += b_re*a_re - b_im*a_im; + *c_im += b_re*a_im + b_im*a_re; +} + + +template inline void ComplexImExp(Real x, Real *a_re, Real *a_im) { + *a_re = std::cos(x); + *a_im = std::sin(x); +} + + +} // end namespace kaldi + + +#endif // KALDI_MATRIX_MATRIX_FUNCTIONS_INL_H_ + diff --git a/torchaudio/csrc/kaldi/matrix/matrix-functions.h b/torchaudio/csrc/kaldi/matrix/matrix-functions.h new file mode 100644 index 00000000000..ca50ddda7c8 --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/matrix-functions.h @@ -0,0 +1,174 @@ +// matrix/matrix-functions.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc.; Jan Silovsky; +// Yanmin Qian; 1991 Henrique (Rico) Malvar (*) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not 
use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_MATRIX_FUNCTIONS_H_ +#define KALDI_MATRIX_MATRIX_FUNCTIONS_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + +/** The function ComplexFft does an Fft on the vector argument v. + v is a vector of even dimension, interpreted for both input + and output as a vector of complex numbers i.e. + \f[ v = ( re_0, im_0, re_1, im_1, ... ) \f] + + If "forward == true" this routine does the Discrete Fourier Transform + (DFT), i.e.: + \f[ vout[m] \leftarrow \sum_{n = 0}^{N-1} vin[i] exp( -2pi m n / N ) \f] + + If "backward" it does the Inverse Discrete Fourier Transform (IDFT) + *WITHOUT THE FACTOR 1/N*, + i.e.: + \f[ vout[m] <-- \sum_{n = 0}^{N-1} vin[i] exp( 2pi m n / N ) \f] + [note the sign difference on the 2 pi for the backward one.] + + Note that this is the definition of the FT given in most texts, but + it differs from the Numerical Recipes version in which the forward + and backward algorithms are flipped. + + Note that you would have to multiply by 1/N after the IDFT to get + back to where you started from. We don't do this because + in some contexts, the transform is made symmetric by multiplying + by sqrt(N) in both passes. The user can do this by themselves. 
+ + See also SplitRadixComplexFft, declared in srfft.h, which is more efficient + but only works if the length of the input is a power of 2. + */ +template void ComplexFft (VectorBase *v, bool forward, Vector *tmp_work = NULL); + +/// ComplexFt is the same as ComplexFft but it implements the Fourier +/// transform in an inefficient way. It is mainly included for testing purposes. +/// See comment for ComplexFft to describe the input and outputs and what it does. +template void ComplexFt (const VectorBase &in, + VectorBase *out, bool forward); + +/// RealFft is a fourier transform of real inputs. Internally it uses +/// ComplexFft. The input dimension N must be even. If forward == true, +/// it transforms from a sequence of N real points to its complex fourier +/// transform; otherwise it goes in the reverse direction. If you call it +/// in the forward and then reverse direction and multiply by 1.0/N, you +/// will get back the original data. +/// The interpretation of the complex-FFT data is as follows: the array +/// is a sequence of complex numbers C_n of length N/2 with (real, im) format, +/// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. +/// See also SplitRadixRealFft, declared in srfft.h, which is more efficient +/// but only works if the length of the input is a power of 2. + +template void RealFft (VectorBase *v, bool forward); + + +/// RealFftInefficient has the same input and output format as RealFft above, but it is +/// an inefficient implementation included for testing purposes. +template void RealFftInefficient (VectorBase *v, bool forward); + +/// ComputeDctMatrix computes a matrix corresponding to the DCT, such that +/// M * v equals the DCT of vector v. M must be square at input. 
+/// This is the type = III DCT with normalization, corresponding to the +/// following equations, where x is the signal and X is the DCT: +/// X_0 = 1/sqrt(2*N) \sum_{n = 0}^{N-1} x_n +/// X_k = 1/sqrt(N) \sum_{n = 0}^{N-1} x_n cos( \pi/N (n + 1/2) k ) +/// This matrix's transpose is its own inverse, so transposing this +/// matrix will give the inverse DCT. +/// Caution: the type III DCT is generally known as the "inverse DCT" (with the +/// type II being the actual DCT), so this function is somewhat mis-named. It +/// was probably done this way for HTK compatibility. We don't change it +/// because it was this way from the start and changing it would affect the +/// feature generation. + +template void ComputeDctMatrix(Matrix *M); + + +/// ComplexMul implements, inline, the complex multiplication b *= a. +template inline void ComplexMul(const Real &a_re, const Real &a_im, + Real *b_re, Real *b_im); + +/// ComplexAddProduct implements, inline, the complex operation c += (a * b). +template inline void ComplexAddProduct(const Real &a_re, const Real &a_im, + const Real &b_re, const Real &b_im, + Real *c_re, Real *c_im); + + +/// ComplexImExp implements a <-- exp(i x), inline. +template inline void ComplexImExp(Real x, Real *a_re, Real *a_im); + + + +/** + ComputePca does a PCA computation, using either outer products + or inner products, whichever is more efficient. Let D be + the dimension of the data points, N be the number of data + points, and G be the PCA dimension we want to retain. We assume + G <= N and G <= D. + + @param X [in] An N x D matrix. Each row of X is a point x_i. + @param U [out] A G x D matrix. Each row of U is a basis element u_i. + @param A [out] An N x D matrix, or NULL. Each row of A is a set of coefficients + in the basis for a point x_i, so A(i, g) is the coefficient of u_i + in x_i. + @param print_eigs [in] If true, prints out diagnostic information about the + eigenvalues. 
+ @param exact [in] If true, does the exact computation; if false, does + a much faster (but almost exact) computation based on the Lanczos + method. +*/ + +template +void ComputePca(const MatrixBase &X, + MatrixBase *U, + MatrixBase *A, + bool print_eigs = false, + bool exact = true); + + + +// This function does: *plus += max(0, a b^T), +// *minus += max(0, -(a b^T)). +template +void AddOuterProductPlusMinus(Real alpha, + const VectorBase &a, + const VectorBase &b, + MatrixBase *plus, + MatrixBase *minus); + +template +inline void AssertSameDim(const MatrixBase &mat1, const MatrixBase &mat2) { + KALDI_ASSERT(mat1.NumRows() == mat2.NumRows() + && mat1.NumCols() == mat2.NumCols()); +} + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + +#include "matrix/matrix-functions-inl.h" + + +#endif diff --git a/torchaudio/csrc/kaldi/matrix/matrix-lib.h b/torchaudio/csrc/kaldi/matrix/matrix-lib.h new file mode 100644 index 00000000000..4fb9e1b16ef --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/matrix-lib.h @@ -0,0 +1,38 @@ +// matrix/matrix-lib.h + +// Copyright 2009-2011 Ondrej Glembek; Microsoft Corporation; Haihua Xu + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +// Include everything from this directory. +// These files include other stuff that we need. 
+#ifndef KALDI_MATRIX_MATRIX_LIB_H_ +#define KALDI_MATRIX_MATRIX_LIB_H_ + +#include "base/kaldi-common.h" +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" +// #include "matrix/sp-matrix.h" +// #include "matrix/tp-matrix.h" +#include "matrix/matrix-functions.h" +#include "matrix/srfft.h" +#include "matrix/compressed-matrix.h" +// #include "matrix/sparse-matrix.h" +#include "matrix/optimization.h" +// #include "matrix/numpy-array.h" + +#endif + diff --git a/torchaudio/csrc/kaldi/matrix/optimization.h b/torchaudio/csrc/kaldi/matrix/optimization.h new file mode 100644 index 00000000000..66309acaad5 --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/optimization.h @@ -0,0 +1,248 @@ +// matrix/optimization.h + +// Copyright 2012 Johns Hopkins University (author: Daniel Povey) +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// (*) incorporates, with permission, FFT code from his book +// "Signal Processing with Lapped Transforms", Artech, 1992. + + + +#ifndef KALDI_MATRIX_OPTIMIZATION_H_ +#define KALDI_MATRIX_OPTIMIZATION_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + + +/// @addtogroup matrix_optimization +/// @{ + +struct LinearCgdOptions { + int32 max_iters; // Maximum number of iters (if >= 0). 
+ BaseFloat max_error; // Maximum 2-norm of the residual A x - b (convergence + // test) + // Every time the residual 2-norm decreases by this recompute_residual_factor + // since the last time it was computed from scratch, recompute it from + // scratch. This helps to keep the computed residual accurate even in the + // presence of roundoff. + BaseFloat recompute_residual_factor; + + LinearCgdOptions(): max_iters(-1), + max_error(0.0), + recompute_residual_factor(0.01) { } +}; + +/* + This function uses linear conjugate gradient descent to approximately solve + the system A x = b. The value of x at entry corresponds to the initial guess + of x. The algorithm continues until the number of iterations equals b.Dim(), + or until the 2-norm of (A x - b) is <= max_error, or until the number of + iterations equals max_iter, whichever happens sooner. It is a requirement + that A be positive definite. + It returns the number of iterations that were actually executed (this is + useful for testing purposes). +*/ +template +int32 LinearCgd(const LinearCgdOptions &opts, + const SpMatrix &A, const VectorBase &b, + VectorBase *x); + + + + + + +/** + This is an implementation of L-BFGS. It pushes responsibility for + determining when to stop, onto the user. There is no call-back here: + everything is done via calls to the class itself (see the example in + matrix-lib-test.cc). This does not implement constrained L-BFGS, but it will + handle constrained problems correctly as long as the function approaches + +infinity (or -infinity for maximization problems) when it gets close to the + bound of the constraint. In these types of problems, you just let the + function value be +infinity for minimization problems, or -infinity for + maximization problems, outside these bounds). +*/ + +struct LbfgsOptions { + bool minimize; // if true, we're minimizing, else maximizing. + int m; // m is the number of stored vectors L-BFGS keeps. 
+ float first_step_learning_rate; // The very first step of L-BFGS is + // like gradient descent. If you want to configure the size of that step, + // you can do it using this variable. + float first_step_length; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // step-length, or 1.0 if the gradient is zero. + float first_step_impr; // If this variable is >0.0, it overrides + // first_step_learning_rate; on the first step we choose an approximate + // Hessian that is the multiple of the identity that would generate this + // amount of objective function improvement (assuming the "real" objf + // was linear). + float c1; // A constant in Armijo rule = Wolfe condition i) + float c2; // A constant in Wolfe condition ii) + float d; // An amount > 1.0 (default 2.0) that we initially multiply or + // divide the step length by, in the line search. + int max_line_search_iters; // after this many iters we restart L-BFGS. + int avg_step_length; // number of iters to avg step length over, in + // RecentStepLength(). + + LbfgsOptions (bool minimize = true): + minimize(minimize), + m(10), + first_step_learning_rate(1.0), + first_step_length(0.0), + first_step_impr(0.0), + c1(1.0e-04), + c2(0.9), + d(2.0), + max_line_search_iters(50), + avg_step_length(4) { } +}; + +template +class OptimizeLbfgs { + public: + /// Initializer takes the starting value of x. + OptimizeLbfgs(const VectorBase &x, + const LbfgsOptions &opts); + + /// This returns the value of the variable x that has the best objective + /// function so far, and the corresponding objective function value if + /// requested. This would typically be called only at the end. + const VectorBase& GetValue(Real *objf_value = NULL) const; + + /// This returns the value at which the function wants us + /// to compute the objective function and gradient. 
+ const VectorBase& GetProposedValue() const { return new_x_; } + + /// Returns the average magnitude of the last n steps (but not + /// more than the number we have stored). Before we have taken + /// any steps, returns +infinity. Note: if the most recent + /// step length was 0, it returns 0, regardless of the other + /// step lengths. This makes it suitable as a convergence test + /// (else we'd generate NaN's). + Real RecentStepLength() const; + + /// The user calls this function to provide the class with the + /// function and gradient info at the point GetProposedValue(). + /// If this point is outside the constraints you can set function_value + /// to {+infinity,-infinity} for {minimization,maximization} problems. + /// In this case the gradient, and also the second derivative (if you call + /// the second overloaded version of this function) will be ignored. + void DoStep(Real function_value, + const VectorBase &gradient); + + /// The user can call this version of DoStep() if it is desired to set some + /// kind of approximate Hessian on this iteration. Note: it is a prerequisite + /// that diag_approx_2nd_deriv must be strictly positive (minimizing), or + /// negative (maximizing). + void DoStep(Real function_value, + const VectorBase &gradient, + const VectorBase &diag_approx_2nd_deriv); + + private: + KALDI_DISALLOW_COPY_AND_ASSIGN(OptimizeLbfgs); + + + // The following variable says what stage of the computation we're at. + // Refer to Algorithm 7.5 (L-BFGS) of Nocedal & Wright, "Numerical + // Optimization", 2nd edition. + // kBeforeStep means we're about to do + /// "compute p_k <-- - H_k \delta f_k" (i.e. Algorithm 7.4). + // kWithinStep means we're at some point within line search; note + // that line search is iterative so we can stay in this state more + // than one time on each iteration.
+ enum ComputationState { + kBeforeStep, + kWithinStep, // This means we're within the step-size computation, and + // have not yet done the 1st function evaluation. + }; + + inline MatrixIndexT Dim() { return x_.Dim(); } + inline MatrixIndexT M() { return opts_.m; } + SubVector Y(MatrixIndexT i) { + return SubVector(data_, (i % M()) * 2); // vector y_i + } + SubVector S(MatrixIndexT i) { + return SubVector(data_, (i % M()) * 2 + 1); // vector s_i + } + // The following are subroutines within DoStep(): + bool AcceptStep(Real function_value, + const VectorBase &gradient); + void Restart(const VectorBase &x, + Real function_value, + const VectorBase &gradient); + void ComputeNewDirection(Real function_value, + const VectorBase &gradient); + void ComputeHifNeeded(const VectorBase &gradient); + void StepSizeIteration(Real function_value, + const VectorBase &gradient); + void RecordStepLength(Real s); + + + LbfgsOptions opts_; + SignedMatrixIndexT k_; // Iteration number, starts from zero. Gets set back to zero + // when we restart. + + ComputationState computation_state_; + bool H_was_set_; // True if the user specified H_; if false, + // we'll use a heuristic to estimate it. + + + Vector x_; // current x. + Vector new_x_; // the x proposed in the line search. + Vector best_x_; // the x with the best objective function so far + // (either the same as x_ or something in the current line search.) + Vector deriv_; // The most recently evaluated derivative-- at x_k. + Vector temp_; + Real f_; // The function evaluated at x_k. + Real best_f_; // the best objective function so far. + Real d_; // a number d > 1.0, but during an iteration we may decrease this, when + // we switch between armijo and wolfe failures. + + int num_wolfe_i_failures_; // the num times we decreased step size. + int num_wolfe_ii_failures_; // the num times we increased step size. + enum { kWolfeI, kWolfeII, kNone } last_failure_type_; // last type of step-search + // failure on this iter. 
+ + Vector H_; // Current inverse-Hessian estimate. May be computed by this class itself, + // or provided by user using 2nd form of DoStep(). + Matrix data_; // dimension (m*2) x dim. Even rows store + // gradients y_i, odd rows store steps s_i. + Vector rho_; // dimension m; rho_(m) = 1/(y_m^T s_m), Eq. 7.17. + + std::vector step_lengths_; // The step sizes we took on the last + // (up to m) iterations; these are not stored in a rotating buffer but + // are shifted by one each time (this is more convenient when we + // restart, as we keep this info past restarting). + + +}; + +/// @} + + +} // end namespace kaldi + + + +#endif + diff --git a/torchaudio/csrc/kaldi/matrix/srfft.h b/torchaudio/csrc/kaldi/matrix/srfft.h new file mode 100644 index 00000000000..98ff782a84a --- /dev/null +++ b/torchaudio/csrc/kaldi/matrix/srfft.h @@ -0,0 +1,141 @@ +// matrix/srfft.h + +// Copyright 2009-2011 Microsoft Corporation; Go Vivace Inc. +// 2014 Daniel Povey +// +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. +// +// This file includes a modified version of code originally published in Malvar, +// H., "Signal processing with lapped transforms, " Artech House, Inc., 1992. The +// current copyright holder of the original code, Henrique S.
Malvar, has given +// his permission for the release of this modified version under the Apache +// License v2.0. + +#ifndef KALDI_MATRIX_SRFFT_H_ +#define KALDI_MATRIX_SRFFT_H_ + +#include "matrix/kaldi-vector.h" +#include "matrix/kaldi-matrix.h" + +namespace kaldi { + +/// @addtogroup matrix_funcs_misc +/// @{ + + +// This class is based on code by Henrique (Rico) Malvar, from his book +// "Signal Processing with Lapped Transforms" (1992). Copied with +// permission, optimized by Go Vivace Inc., and converted into C++ by +// Microsoft Corporation +// This is a more efficient way of doing the complex FFT than ComplexFft +// (declared in matrix-functions.h), but it only works for powers of 2. +// Note: in multi-threaded code, you would need to have one of these objects per +// thread, because multiple calls to Compute in parallel would not work. +template +class SplitRadixComplexFft { + public: + typedef MatrixIndexT Integer; + + // N is the number of complex points (must be a power of two, or this + // will crash). Note that the constructor does some work so it's best to + // initialize the object once and do the computation many times. + SplitRadixComplexFft(Integer N); + + // Copy constructor + SplitRadixComplexFft(const SplitRadixComplexFft &other); + + // Does the FFT computation, given pointers to the real and + // imaginary parts. If "forward", do the forward FFT; else + // do the inverse FFT (without the 1/N factor). + // xr and xi are pointers to zero-based arrays of size N, + // containing the real and imaginary parts + // respectively. + void Compute(Real *xr, Real *xi, bool forward) const; + + // This version of Compute takes a single array of size N*2, + // containing [ r0 im0 r1 im1 ... ]. Otherwise its behavior is the + // same as the version above. + void Compute(Real *x, bool forward); + + + // This version of Compute is const; it operates on an array of size N*2 + // containing [ r0 im0 r1 im1 ...
], but it uses the argument "temp_buffer" as + // temporary storage instead of a class-member variable. It will allocate it if + // needed. + void Compute(Real *x, bool forward, std::vector *temp_buffer) const; + + ~SplitRadixComplexFft(); + + protected: + // temp_buffer_ is allocated only if someone calls Compute with only one Real* + // argument and we need a temporary buffer while creating interleaved data. + std::vector temp_buffer_; + private: + void ComputeTables(); + void ComputeRecursive(Real *xr, Real *xi, Integer logn) const; + void BitReversePermute(Real *x, Integer logn) const; + + Integer N_; + Integer logn_; // log(N) + + Integer *brseed_; + // brseed is Evans' seed table, ref: (Ref: D. M. W. + // Evans, "An improved digit-reversal permutation algorithm ...", + // IEEE Trans. ASSP, Aug. 1987, pp. 1120-1125). + Real **tab_; // Tables of butterfly coefficients. + + // Disallow assignment. + SplitRadixComplexFft &operator =(const SplitRadixComplexFft &other); +}; + +template +class SplitRadixRealFft: private SplitRadixComplexFft { + public: + SplitRadixRealFft(MatrixIndexT N): // will fail unless N>=4 and N is a power of 2. + SplitRadixComplexFft (N/2), N_(N) { } + + // Copy constructor + SplitRadixRealFft(const SplitRadixRealFft &other): + SplitRadixComplexFft(other), N_(other.N_) { } + + /// If forward == true, this function transforms from a sequence of N real points to its complex fourier + /// transform; otherwise it goes in the reverse direction. If you call it + /// in the forward and then reverse direction and multiply by 1.0/N, you + /// will get back the original data. + /// The interpretation of the complex-FFT data is as follows: the array + /// is a sequence of complex numbers C_n of length N/2 with (real, im) format, + /// i.e. [real0, real_{N/2}, real1, im1, real2, im2, real3, im3, ...]. 
+ void Compute(Real *x, bool forward); + + + /// This is as the other Compute() function, but it is a const version that + /// uses a user-supplied buffer. + void Compute(Real *x, bool forward, std::vector *temp_buffer) const; + + private: + // Disallow assignment. + SplitRadixRealFft &operator =(const SplitRadixRealFft &other); + int N_; +}; + + +/// @} end of "addtogroup matrix_funcs_misc" + +} // end namespace kaldi + + +#endif + diff --git a/torchaudio/csrc/kaldi/register.cpp b/torchaudio/csrc/kaldi/register.cpp new file mode 100644 index 00000000000..d24e04640bc --- /dev/null +++ b/torchaudio/csrc/kaldi/register.cpp @@ -0,0 +1,10 @@ +#include + +TORCH_LIBRARY_FRAGMENT(torchaudio, m) { + ////////////////////////////////////////////////////////////////////////////// + // kaldi.h + ////////////////////////////////////////////////////////////////////////////// + m.def( + "torchaudio::kaldi_ComputeKaldiPitch", + &torchaudio::kaldi::ComputeKaldiPitch); +} diff --git a/torchaudio/csrc/kaldi/util/common-utils.h b/torchaudio/csrc/kaldi/util/common-utils.h new file mode 100644 index 00000000000..48d199e97ec --- /dev/null +++ b/torchaudio/csrc/kaldi/util/common-utils.h @@ -0,0 +1,31 @@ +// util/common-utils.h + +// Copyright 2009-2011 Microsoft Corporation + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at + +// http://www.apache.org/licenses/LICENSE-2.0 + +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. 
+#ifndef KALDI_UTIL_COMMON_UTILS_H_ +#define KALDI_UTIL_COMMON_UTILS_H_ + +#include "base/kaldi-common.h" +#include "util/parse-options.h" +// #include "util/kaldi-io.h" +// #include "util/simple-io-funcs.h" +// #include "util/kaldi-holder.h" +// #include "util/kaldi-table.h" +// #include "util/table-types.h" +// #include "util/text-utils.h" + +#endif // KALDI_UTIL_COMMON_UTILS_H_ diff --git a/torchaudio/csrc/kaldi/util/parse-options.h b/torchaudio/csrc/kaldi/util/parse-options.h new file mode 100644 index 00000000000..6884a40c10e --- /dev/null +++ b/torchaudio/csrc/kaldi/util/parse-options.h @@ -0,0 +1,264 @@ +// util/parse-options.h + +// Copyright 2009-2011 Karel Vesely; Microsoft Corporation; +// Saarland University (Author: Arnab Ghoshal); +// Copyright 2012-2013 Frantisek Skala; Arnab Ghoshal + +// See ../../COPYING for clarification regarding multiple authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED +// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, +// MERCHANTABLITY OR NON-INFRINGEMENT. +// See the Apache 2 License for the specific language governing permissions and +// limitations under the License. + +#ifndef KALDI_UTIL_PARSE_OPTIONS_H_ +#define KALDI_UTIL_PARSE_OPTIONS_H_ + +#include +#include +#include + +#include "base/kaldi-common.h" +#include "itf/options-itf.h" + +namespace kaldi { + +/// The class ParseOptions is for parsing command-line options; see +/// \ref parse_options for more documentation. 
+class ParseOptions : public OptionsItf { + public: + explicit ParseOptions(const char *usage) : + print_args_(true), help_(false), usage_(usage), argc_(0), argv_(NULL), + prefix_(""), other_parser_(NULL) { +#if !defined(_MSC_VER) && !defined(__CYGWIN__) // This is just a convenient place to set the stderr to line + setlinebuf(stderr); // buffering mode, since it's called at program start. +#endif // This helps ensure different programs' output is not mixed up. + RegisterStandard("config", &config_, "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + RegisterStandard("verbose", &g_kaldi_verbose_level, + "Verbose level (higher->more logging)"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. 
+ MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, OptionsItf *other); + + ~ParseOptions() {} + + // Methods from the interface + void Register(const std::string &name, + bool *ptr, const std::string &doc); + void Register(const std::string &name, + int32 *ptr, const std::string &doc); + void Register(const std::string &name, + uint32 *ptr, const std::string &doc); + void Register(const std::string &name, + float *ptr, const std::string &doc); + void Register(const std::string &name, + double *ptr, const std::string &doc); + void Register(const std::string &name, + std::string *ptr, const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template + void RegisterStandard(const std::string &name, + T *ptr, const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line values given. + Returns the first position in argv that was not used. + [typically not useful: use NumArgs() and GetArg(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. + void PrintUsage(bool print_command_line = false); + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os); + + /// Reads the options values from a config file. Must be called after + /// registering all options.
This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. + static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32 *i, const std::string &doc, bool is_standard); + /// Register unsigned int32 variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32 *u, + const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type].
+ void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template + void RegisterCommon(const std::string &name, + T *ptr, const std::string &doc, bool is_standard); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. + bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str); + int32 ToInt(const std::string &str); + uint32 ToUint(const std::string &str); + float ToFloat(const std::string &str); + double ToDouble(const std::string &str); + + // maps for option variables + std::map bool_map_; + std::map int_map_; + std::map uint_map_; + std::map float_map_; + std::map double_map_; + std::map string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() {} + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + typedef std::map DocMapType; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector positional_args_; + const char *usage_; + int argc_; + const char 
*const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + OptionsItf *other_parser_; + protected: + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. + void SplitLongArg(const std::string &in, std::string *key, + std::string *value, bool *has_equal_sign); + + void NormalizeArgName(std::string *str); +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(OptionsItf *opts)" which it can call to register the +/// ParseOptions object. +template void ReadConfigFromFile(const std::string &config_filename, + C *c) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. 
+template void ReadConfigsFromFile(const std::string &conf, + C1 *c1, C2 *c2) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << conf << "'"; + ParseOptions po(usage_str.str().c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(conf); +} + + + +} // namespace kaldi + +#endif // KALDI_UTIL_PARSE_OPTIONS_H_ diff --git a/torchaudio/functional/__init__.py b/torchaudio/functional/__init__.py index 688708296be..c5aa6156b7c 100644 --- a/torchaudio/functional/__init__.py +++ b/torchaudio/functional/__init__.py @@ -3,6 +3,7 @@ angle, complex_norm, compute_deltas, + compute_kaldi_pitch, create_dct, create_fb_matrix, DB_to_amplitude, diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index d1c44cf94fe..55206df02f5 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -13,6 +13,7 @@ "amplitude_to_DB", "DB_to_amplitude", "compute_deltas", + "compute_kaldi_pitch", "create_fb_matrix", "create_dct", "compute_deltas", @@ -971,3 +972,37 @@ def spectral_centroid( device=specgram.device).reshape((-1, 1)) freq_dim = -2 return (freqs * specgram).sum(dim=freq_dim) / specgram.sum(dim=freq_dim) + + +def compute_kaldi_pitch( + wave: torch.Tensor, + sample_frequency: float, + frame_length: float = 25.0, + frame_shift: float = 10.0, + preemph_coeff: float = 0.0, + min_f0: float = 50, + max_f0: float = 400, + soft_min_f0: float = 10.0, + penalty_factor: float = 0.1, + lowpass_cutoff: float = 1000, + resample_frequency: float = 4000, + delta_pitch: float = 0.005, + nccf_ballast: float = 7000, + lowpass_filter_width: int = 1, + upsample_filter_width: int = 5, + max_frames_latency: int = 0, + frames_per_chunk: int = 0, + simulate_first_pass_online: bool = False, + recompute_frame: int = 500, + nccf_ballast_online: bool = False, + snip_edges: bool = True, +): + """Equivalent of Kaldi's `compute-kaldi-pitch-feats`: forwards every argument positionally, in signature order, to the `torchaudio::kaldi_ComputeKaldiPitch` op and returns its result.""" + return torch.ops.torchaudio.kaldi_ComputeKaldiPitch( +
wave, sample_frequency, frame_length, frame_shift, preemph_coeff, + min_f0, max_f0, soft_min_f0, penalty_factor, lowpass_cutoff, + resample_frequency, delta_pitch, nccf_ballast, + lowpass_filter_width, upsample_filter_width, max_frames_latency, + frames_per_chunk, simulate_first_pass_online, recompute_frame, + nccf_ballast_online, snip_edges, + )