-
Notifications
You must be signed in to change notification settings - Fork 661
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
24 changed files
with
1,025 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
{"sample_rate": 8000} | ||
{"sample_rate": 8000, "frames_per_chunk": 200} | ||
{"sample_rate": 8000, "frames_per_chunk": 200, "simulate_first_pass_online": true} | ||
{"sample_rate": 16000} | ||
{"sample_rate": 44100} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import subprocess | ||
|
||
import torch | ||
|
||
|
||
def convert_args(**kwargs):
    """Translate keyword arguments into ``--key=value`` command line options.

    For every keyword argument, underscores in the name are replaced with
    dashes and the name is prefixed with ``--``. The name ``sample_rate`` is
    renamed to ``sample_frequency`` before conversion. Boolean values are
    rendered in lower case (``true`` / ``false``); every other value is
    rendered with ``str``.

    Args:
        **kwargs: Option names and their values.

    Returns:
        list of str: One ``--key=value`` entry per keyword argument, in the
        order the arguments were given.
    """
    args = []
    for key, value in kwargs.items():
        if key == 'sample_rate':
            key = 'sample_frequency'
        option = '--' + key.replace('_', '-')
        # Check the type instead of using ``value in [True, False]``: the
        # membership test compares with ``==``, which raises for array-like
        # values whose comparison result has no single truth value.
        text = str(value).lower() if isinstance(value, bool) else str(value)
        args.append(f'{option}={text}')
    return args
|
||
|
||
def run_kaldi(command, input_type, input_value):
    """Run provided Kaldi command, pass a tensor and get the resulting tensor

    Args:
        command: Command (argument list) handed to ``subprocess.Popen``.
        input_type: str
            'ark' or 'scp'
        input_value:
            Tensor for 'ark'
            string for 'scp' (path to an audio file)

    Returns:
        Tensor: The matrix read back from the command's standard output
        under the same key that was used for the input.

    Raises:
        NotImplementedError: If ``input_type`` is neither 'ark' nor 'scp'.
    """
    # Reject an unknown input type before spawning anything, so that no child
    # process is left behind when the caller made a mistake.
    if input_type not in ('ark', 'scp'):
        raise NotImplementedError('Unexpected type')

    import kaldi_io

    key = 'foo'
    # The context manager closes the pipes and waits for the child process on
    # exit, including when writing the input or reading the output raises.
    with subprocess.Popen(command, stdin=subprocess.PIPE, stdout=subprocess.PIPE) as process:
        if input_type == 'ark':
            kaldi_io.write_mat(process.stdin, input_value.cpu().numpy(), key=key)
        else:
            process.stdin.write(f'{key} {input_value}'.encode('utf8'))
        # Closing stdin tells the command that the input is complete.
        process.stdin.close()
        result = dict(kaldi_io.read_mat_ark(process.stdout))[key]
    return torch.from_numpy(result.copy())  # copy suppresses some torch warning
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
9 changes: 9 additions & 0 deletions
9
test/torchaudio_unittest/functional/kaldi_compatibility_cpu_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import torch | ||
|
||
from torchaudio_unittest.common_utils import PytorchTestCase | ||
from .kaldi_compatibility_test_impl import KaldiCPUOnly | ||
|
||
|
||
class TestKaldiCPUOnly(KaldiCPUOnly, PytorchTestCase):
    """Concrete test case running the ``KaldiCPUOnly`` tests on CPU with float32."""

    # Device and dtype consumed by the shared test implementation.
    device = torch.device('cpu')
    dtype = torch.float32
37 changes: 37 additions & 0 deletions
37
test/torchaudio_unittest/functional/kaldi_compatibility_test_impl.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from parameterized import parameterized | ||
import torchaudio.functional as F | ||
|
||
from torchaudio_unittest.common_utils import ( | ||
get_sinusoid, | ||
load_params, | ||
save_wav, | ||
skipIfNoExec, | ||
TempDirMixin, | ||
TestBaseMixin, | ||
) | ||
from torchaudio_unittest.common_utils.kaldi_utils import ( | ||
convert_args, | ||
run_kaldi, | ||
) | ||
|
||
|
||
class KaldiCPUOnly(TempDirMixin, TestBaseMixin):
    """Compatibility tests comparing torchaudio functions against Kaldi executables."""

    def assert_equal(self, output, *, expected, rtol=None, atol=None):
        """Assert ``output`` equals ``expected`` once ``expected`` is cast to ``self.dtype`` / ``self.device``."""
        reference = expected.to(dtype=self.dtype, device=self.device)
        self.assertEqual(output, reference, rtol=rtol, atol=atol)

    @parameterized.expand(load_params('kaldi_test_pitch_args.json'))
    @skipIfNoExec('compute-kaldi-pitch-feats')
    def test_pitch_feats(self, kwargs):
        """compute_kaldi_pitch produces numerically compatible result with compute-kaldi-pitch-feats"""
        rate = kwargs['sample_rate']

        # torchaudio side: feed the float32 sinusoid directly to the function.
        float_wave = get_sinusoid(dtype='float32', sample_rate=rate)
        result = F.compute_kaldi_pitch(float_wave[0], **kwargs)

        # Kaldi side: write an int16 sinusoid to a temporary wav file and hand
        # its path to the executable through an scp entry on stdin.
        int_wave = get_sinusoid(dtype='int16', sample_rate=rate)
        wav_path = self.get_temp_path('test.wav')
        save_wav(wav_path, int_wave, rate)

        kaldi_cli = ['compute-kaldi-pitch-feats', *convert_args(**kwargs), 'scp:-', 'ark:-']
        reference = run_kaldi(kaldi_cli, 'scp', wav_path)
        self.assert_equal(result, expected=reference)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
# Location of the Kaldi checkout (git submodule) relative to this file.
set(KALDI_REPO ${CMAKE_CURRENT_SOURCE_DIR}/submodule)

# Apply custom patch.
# Local modifications to tracked files are discarded first
# (presumably so that re-running the configure step does not try to apply
# the patch on top of an already patched tree).
execute_process(
  COMMAND git checkout .
  WORKING_DIRECTORY ${KALDI_REPO}
)
execute_process(
  COMMAND git apply ../kaldi.patch
  WORKING_DIRECTORY ${KALDI_REPO}
)

# Update the version string
execute_process(
  COMMAND sh get_version.sh
  WORKING_DIRECTORY ${KALDI_REPO}/src/base
)

# Sources of the static library: the vector/matrix files under src/ plus the
# selected files from the submodule.
set(KALDI_SOURCES
  src/matrix/kaldi-vector.cc
  src/matrix/kaldi-matrix.cc
  submodule/src/base/kaldi-error.cc
  submodule/src/base/kaldi-math.cc
  submodule/src/feat/feature-functions.cc
  submodule/src/feat/pitch-functions.cc
  submodule/src/feat/resample.cc
)

add_library(kaldi STATIC ${KALDI_SOURCES})
# src is listed before submodule/src in the include path.
target_include_directories(kaldi PUBLIC src submodule/src)
target_link_libraries(kaldi ${TORCH_LIBRARIES})
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Custom Kaldi build | ||
|
||
This directory contains the original Kaldi repository (as a submodule), [the custom implementation of Kaldi's vector/matrix](./src), and the build script.
|
||
We use the custom build process so that the resulting library only contains what torchaudio needs. | ||
We use the custom vector/matrix implementation so that we can use the same BLAS library that PyTorch is compiled with, and so that we can (hopefully, in the future) take advantage of other PyTorch features (such as differentiability and GPU support). The downside of this approach is that it adds a lot of overhead compared to the original Kaldi (operator dispatch and element-wise processing, which PyTorch is not efficient at). We can improve this gradually, and if you are interested in helping, please let us know by opening an issue.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
diff --git a/src/base/kaldi-types.h b/src/base/kaldi-types.h | ||
index 7ebf4f853..c15b288b2 100644 | ||
--- a/src/base/kaldi-types.h | ||
+++ b/src/base/kaldi-types.h | ||
@@ -41,6 +41,7 @@ typedef float BaseFloat; | ||
|
||
// for discussion on what to do if you need compile kaldi | ||
// without OpenFST, see the bottom of this this file | ||
+/* | ||
#include <fst/types.h> | ||
|
||
namespace kaldi { | ||
@@ -53,10 +54,10 @@ namespace kaldi { | ||
typedef float float32; | ||
typedef double double64; | ||
} // end namespace kaldi | ||
+*/ | ||
|
||
// In a theoretical case you decide compile Kaldi without the OpenFST | ||
// comment the previous namespace statement and uncomment the following | ||
-/* | ||
namespace kaldi { | ||
typedef int8_t int8; | ||
typedef int16_t int16; | ||
@@ -70,6 +71,5 @@ namespace kaldi { | ||
typedef float float32; | ||
typedef double double64; | ||
} // end namespace kaldi | ||
-*/ | ||
|
||
#endif // KALDI_BASE_KALDI_TYPES_H_ | ||
diff --git a/src/matrix/matrix-lib.h b/src/matrix/matrix-lib.h | ||
index b6059b06c..4fb9e1b16 100644 | ||
--- a/src/matrix/matrix-lib.h | ||
+++ b/src/matrix/matrix-lib.h | ||
@@ -25,14 +25,14 @@ | ||
#include "base/kaldi-common.h" | ||
#include "matrix/kaldi-vector.h" | ||
#include "matrix/kaldi-matrix.h" | ||
-#include "matrix/sp-matrix.h" | ||
-#include "matrix/tp-matrix.h" | ||
+// #include "matrix/sp-matrix.h" | ||
+// #include "matrix/tp-matrix.h" | ||
#include "matrix/matrix-functions.h" | ||
#include "matrix/srfft.h" | ||
#include "matrix/compressed-matrix.h" | ||
-#include "matrix/sparse-matrix.h" | ||
+// #include "matrix/sparse-matrix.h" | ||
#include "matrix/optimization.h" | ||
-#include "matrix/numpy-array.h" | ||
+// #include "matrix/numpy-array.h" | ||
|
||
#endif | ||
|
||
diff --git a/src/util/common-utils.h b/src/util/common-utils.h | ||
index cfb0c255c..48d199e97 100644 | ||
--- a/src/util/common-utils.h | ||
+++ b/src/util/common-utils.h | ||
@@ -21,11 +21,11 @@ | ||
|
||
#include "base/kaldi-common.h" | ||
#include "util/parse-options.h" | ||
-#include "util/kaldi-io.h" | ||
-#include "util/simple-io-funcs.h" | ||
-#include "util/kaldi-holder.h" | ||
-#include "util/kaldi-table.h" | ||
-#include "util/table-types.h" | ||
-#include "util/text-utils.h" | ||
+// #include "util/kaldi-io.h" | ||
+// #include "util/simple-io-funcs.h" | ||
+// #include "util/kaldi-holder.h" | ||
+// #include "util/kaldi-table.h" | ||
+// #include "util/table-types.h" | ||
+// #include "util/text-utils.h" | ||
|
||
#endif // KALDI_UTIL_COMMON_UTILS_H_ |
Oops, something went wrong.