From 049fb9f45139291d408d0c5b44058e2bb00c79ab Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Thu, 16 Nov 2023 14:20:41 +0800 Subject: [PATCH] Add Python APIs for WeNet CTC models (#428) --- .github/scripts/test-python.sh | 45 +++++++ .github/workflows/mfc.yaml | 21 +++- .github/workflows/run-python-test.yaml | 2 + CMakeLists.txt | 2 +- python-api-examples/generate-subtitles.py | 33 ++++++ python-api-examples/non_streaming_server.py | 42 ++++++- python-api-examples/offline-decode-files.py | 38 +++++- python-api-examples/online-decode-files.py | 49 ++++++++ python-api-examples/streaming_server.py | 30 ++++- .../python/sherpa_onnx/offline_recognizer.py | 71 ++++++++++- .../python/sherpa_onnx/online_recognizer.py | 111 +++++++++++++++++- .../python/tests/test_offline_recognizer.py | 47 ++++++++ .../python/tests/test_online_recognizer.py | 58 +++++++++ 13 files changed, 538 insertions(+), 11 deletions(-) diff --git a/.github/scripts/test-python.sh b/.github/scripts/test-python.sh index 6567dd59e..5491ab6fb 100755 --- a/.github/scripts/test-python.sh +++ b/.github/scripts/test-python.sh @@ -8,6 +8,51 @@ log() { echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" } +wenet_models=( +sherpa-onnx-zh-wenet-aishell +sherpa-onnx-zh-wenet-aishell2 +sherpa-onnx-zh-wenet-wenetspeech +sherpa-onnx-zh-wenet-multi-cn +sherpa-onnx-en-wenet-librispeech +sherpa-onnx-en-wenet-gigaspeech +) + +mkdir -p /tmp/icefall-models +dir=/tmp/icefall-models + +for name in ${wenet_models[@]}; do + repo_url=https://huggingface.co/csukuangfj/$name + log "Start testing ${repo_url}" + repo=$dir/$(basename $repo_url) + log "Download pretrained model and test-data from $repo_url" + pushd $dir + GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url + cd $repo + git lfs pull --include "*.onnx" + ls -lh *.onnx + popd + + python3 ./python-api-examples/offline-decode-files.py \ + --tokens=$repo/tokens.txt \ + --wenet-ctc=$repo/model.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + + python3 ./python-api-examples/online-decode-files.py \ + --tokens=$repo/tokens.txt \ + --wenet-ctc=$repo/model-streaming.onnx \ + $repo/test_wavs/0.wav \ + $repo/test_wavs/1.wav \ + $repo/test_wavs/8k.wav + + python3 sherpa-onnx/python/tests/test_offline_recognizer.py --verbose + + python3 sherpa-onnx/python/tests/test_online_recognizer.py --verbose + + rm -rf $repo +done + log "Offline TTS test" # test waves are saved in ./tts mkdir ./tts diff --git a/.github/workflows/mfc.yaml b/.github/workflows/mfc.yaml index 78dbfe615..578cfec6e 100644 --- a/.github/workflows/mfc.yaml +++ b/.github/workflows/mfc.yaml @@ -85,10 +85,19 @@ jobs: arch=${{ matrix.arch }} cd mfc-examples/$arch/Release - cp StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe - cp NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe ls -lh + cp -v StreamingSpeechRecognition.exe sherpa-onnx-streaming-${SHERPA_ONNX_VERSION}.exe + cp -v NonStreamingSpeechRecognition.exe sherpa-onnx-non-streaming-${SHERPA_ONNX_VERSION}.exe + cp -v NonStreamingTextToSpeech.exe ../sherpa-onnx-non-streaming-tts-${SHERPA_ONNX_VERSION}.exe + ls -lh + + - name: Upload artifact tts + uses: actions/upload-artifact@v3 + with: + name: non-streaming-tts-${{ matrix.arch }} + path: ./mfc-examples/${{ matrix.arch }}/Release/NonStreamingTextToSpeech.exe + - name: Upload artifact uses: actions/upload-artifact@v3 with: @@ -116,3 +125,11 @@ jobs: file_glob: true overwrite: true file: ./mfc-examples/${{ matrix.arch }}/Release/sherpa-onnx-non-streaming-*.exe + + - name: Release pre-compiled binaries and libs for Windows ${{ matrix.arch }} + if: github.repository_owner == 'csukuangfj' || github.repository_owner == 'k2-fsa' && github.event_name == 'push' && contains(github.ref, 'refs/tags/') + uses: svenstaro/upload-release-action@v2 + with: + file_glob: true + overwrite: true + file: ./mfc-examples/${{ matrix.arch }}/sherpa-onnx-non-streaming-*.exe diff --git a/.github/workflows/run-python-test.yaml b/.github/workflows/run-python-test.yaml index 036e06dd1..348343fe9 100644 --- a/.github/workflows/run-python-test.yaml +++ b/.github/workflows/run-python-test.yaml @@ -10,6 +10,7 @@ on: - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' + - 'python-api-examples/**' pull_request: branches: - master @@ -19,6 +20,7 @@ on: - 'CMakeLists.txt' - 'cmake/**' - 'sherpa-onnx/csrc/*' + - 'python-api-examples/**' workflow_dispatch: concurrency: diff --git a/CMakeLists.txt b/CMakeLists.txt index f821c5c84..748e835b3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.8.9") +set(SHERPA_ONNX_VERSION "1.8.10") # Disable warning about # diff --git a/python-api-examples/generate-subtitles.py b/python-api-examples/generate-subtitles.py index 86a8fec2b..e9edb03c9 100755 --- a/python-api-examples/generate-subtitles.py +++ b/python-api-examples/generate-subtitles.py @@ -58,6 +58,15 @@ --num-threads=2 \ /path/to/test.mp4 +(4) For WeNet CTC models + +./python-api-examples/generate-subtitles.py \ + --silero-vad-model=/path/to/silero_vad.onnx \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + --num-threads=2 \ + /path/to/test.mp4 + Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html to install sherpa-onnx and to download non-streaming pre-trained models @@ -121,6 +130,13 @@ def get_args(): help="Path to the model.onnx from Paraformer", ) + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the CTC model.onnx from WeNet", + ) + parser.add_argument( "--num-threads", type=int, @@ -215,6 +231,7 @@ def assert_file_exists(filename: str): def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: if args.encoder: assert len(args.paraformer) == 0, args.paraformer + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder @@ -234,6 +251,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: debug=args.debug, ) elif args.paraformer: + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder @@ -248,6 +266,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: decoding_method=args.decoding_method, debug=args.debug, ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) elif args.whisper_encoder: assert_file_exists(args.whisper_encoder) assert_file_exists(args.whisper_decoder) diff --git a/python-api-examples/non_streaming_server.py b/python-api-examples/non_streaming_server.py index 902f658c4..2ca45a76e 100755 --- a/python-api-examples/non_streaming_server.py +++ b/python-api-examples/non_streaming_server.py @@ -58,7 +58,19 @@ --nemo-ctc ./sherpa-onnx-nemo-ctc-en-conformer-medium/model.onnx \ --tokens ./sherpa-onnx-nemo-ctc-en-conformer-medium/tokens.txt -(4) Use a Whisper model +(4) Use a non-streaming CTC model from WeNet + +cd /path/to/sherpa-onnx +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech +cd sherpa-onnx-zh-wenet-wenetspeech +git lfs pull --include "*.onnx" +cd .. + +python3 ./python-api-examples/non_streaming_server.py \ + --wenet-ctc ./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens ./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt + +(5) Use a Whisper model cd /path/to/sherpa-onnx GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-tiny.en @@ -210,6 +222,15 @@ def add_nemo_ctc_model_args(parser: argparse.ArgumentParser): ) +def add_wenet_ctc_model_args(parser: argparse.ArgumentParser): + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + + def add_tdnn_ctc_model_args(parser: argparse.ArgumentParser): parser.add_argument( "--tdnn-model", @@ -261,6 +282,7 @@ def add_model_args(parser: argparse.ArgumentParser): add_transducer_model_args(parser) add_paraformer_model_args(parser) add_nemo_ctc_model_args(parser) + add_wenet_ctc_model_args(parser) add_tdnn_ctc_model_args(parser) add_whisper_model_args(parser) @@ -804,6 +826,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: if args.encoder: assert len(args.paraformer) == 0, args.paraformer assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model @@ -827,6 +850,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: ) elif args.paraformer: assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model @@ -842,6 +866,7 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: decoding_method=args.decoding_method, ) elif args.nemo_ctc: + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model @@ -856,6 +881,21 @@ def create_recognizer(args) -> sherpa_onnx.OfflineRecognizer: feature_dim=args.feat_dim, decoding_method=args.decoding_method, ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + ) elif args.whisper_encoder: assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.whisper_encoder) diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index ad8d1ebaf..16b11360d 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -59,7 +59,16 @@ ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav -(5) For tdnn models of the yesno recipe from icefall +(5) For CTC models from WeNet + +python3 ./python-api-examples/offline-decode-files.py \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + +(6) For tdnn models of the yesno recipe from icefall python3 ./python-api-examples/offline-decode-files.py \ --sample-rate=8000 \ @@ -154,6 +163,13 @@ def get_args(): help="Path to the model.onnx from NeMo CTC", ) + parser.add_argument( + "--wenet-ctc", + default="", + type=str, + help="Path to the model.onnx from WeNet CTC", + ) + parser.add_argument( "--tdnn-model", default="", @@ -254,6 +270,7 @@ def assert_file_exists(filename: str): "https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it" ) + def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]: """ Args: @@ -287,6 +304,7 @@ def main(): if args.encoder: assert len(args.paraformer) == 0, args.paraformer assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model @@ -310,6 +328,7 @@ def main(): ) elif args.paraformer: assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model @@ -326,6 +345,7 @@ def main(): debug=args.debug, ) elif args.nemo_ctc: + assert len(args.wenet_ctc) == 0, args.wenet_ctc assert len(args.whisper_encoder) == 0, args.whisper_encoder assert len(args.whisper_decoder) == 0, args.whisper_decoder assert len(args.tdnn_model) == 0, args.tdnn_model @@ -341,6 +361,22 @@ def main(): decoding_method=args.decoding_method, debug=args.debug, ) + elif args.wenet_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert len(args.tdnn_model) == 0, args.tdnn_model + + assert_file_exists(args.wenet_ctc) + + recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc( + model=args.wenet_ctc, + tokens=args.tokens, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feature_dim, + decoding_method=args.decoding_method, + debug=args.debug, + ) elif args.whisper_encoder: assert len(args.tdnn_model) == 0, args.tdnn_model assert_file_exists(args.whisper_encoder) diff --git a/python-api-examples/online-decode-files.py b/python-api-examples/online-decode-files.py index c6606f94b..56f9dc525 100755 --- a/python-api-examples/online-decode-files.py +++ b/python-api-examples/online-decode-files.py @@ -37,8 +37,25 @@ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/3.wav \ ./sherpa-onnx-streaming-paraformer-bilingual-zh-en/test_wavs/8k.wav +(3) Streaming Conformer CTC from WeNet + +GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-zh-wenet-wenetspeech +cd sherpa-onnx-zh-wenet-wenetspeech +git lfs pull --include "*.onnx" + +./python-api-examples/online-decode-files.py \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \ + ./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav + + + Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html +and +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html to install sherpa-onnx and to download streaming pre-trained models. """ import argparse @@ -92,6 +109,26 @@ def get_args(): help="Path to the paraformer decoder model", ) + parser.add_argument( + "--wenet-ctc", + type=str, + help="Path to the wenet ctc model model", + ) + + parser.add_argument( + "--wenet-ctc-chunk-size", + type=int, + default=16, + help="The --chunk-size parameter for streaming WeNet models", + ) + + parser.add_argument( + "--wenet-ctc-num-left-chunks", + type=int, + default=4, + help="The --num-left-chunks parameter for streaming WeNet models", + ) + parser.add_argument( "--num-threads", type=int, @@ -249,6 +286,18 @@ def main(): feature_dim=80, decoding_method="greedy_search", ) + elif args.wenet_ctc: + recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( + tokens=args.tokens, + model=args.wenet_ctc, + chunk_size=args.wenet_ctc_chunk_size, + num_left_chunks=args.wenet_ctc_num_left_chunks, + num_threads=args.num_threads, + provider=args.provider, + sample_rate=16000, + feature_dim=80, + decoding_method="greedy_search", + ) else: raise ValueError("Please provide a model") diff --git a/python-api-examples/streaming_server.py b/python-api-examples/streaming_server.py index b5a37a40e..c40428a83 100755 --- a/python-api-examples/streaming_server.py +++ b/python-api-examples/streaming_server.py @@ -40,10 +40,17 @@ Please refer to https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/index.html +https://k2-fsa.github.io/sherpa/onnx/pretrained_models/wenet/index.html to download pre-trained models. The model in the above help messages is from https://k2-fsa.github.io/sherpa/onnx/pretrained_models/online-transducer/zipformer-transducer-models.html#csukuangfj-sherpa-onnx-streaming-zipformer-bilingual-zh-en-2023-02-20-bilingual-chinese-english + +To use a WeNet streaming Conformer CTC model, please use + +python3 ./python-api-examples/streaming_server.py \ + --tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \ + --wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model-streaming.onnx """ import argparse @@ -130,6 +137,12 @@ def add_model_args(parser: argparse.ArgumentParser): help="Path to the transducer joiner model.", ) + parser.add_argument( + "--wenet-ctc", + type=str, + help="Path to the model.onnx from WeNet", + ) + parser.add_argument( "--paraformer-encoder", type=str, @@ -212,7 +225,6 @@ def add_hotwords_args(parser: argparse.ArgumentParser): ) - def add_modified_beam_search_args(parser: argparse.ArgumentParser): parser.add_argument( "--num-active-paths", @@ -393,6 +405,20 @@ def create_recognizer(args) -> sherpa_onnx.OnlineRecognizer: rule3_min_utterance_length=args.rule3_min_utterance_length, provider=args.provider, ) + elif args.wenet_ctc: + recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( + tokens=args.tokens, + model=args.wenet_ctc, + num_threads=args.num_threads, + sample_rate=args.sample_rate, + feature_dim=args.feat_dim, + decoding_method=args.decoding_method, + enable_endpoint_detection=args.use_endpoint != 0, + rule1_min_trailing_silence=args.rule1_min_trailing_silence, + rule2_min_trailing_silence=args.rule2_min_trailing_silence, + rule3_min_utterance_length=args.rule3_min_utterance_length, + provider=args.provider, + ) else: raise ValueError("Please provide a model") @@ -727,6 +753,8 @@ def check_args(args): assert Path( args.paraformer_decoder ).is_file(), f"{args.paraformer_decoder} does not exist" + elif args.wenet_ctc: + assert Path(args.wenet_ctc).is_file(), f"{args.wenet_ctc} does not exist" else: raise ValueError("Please provide a model") diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py index 21bd8d588..57afb2dca 100644 --- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py @@ -9,15 +9,16 @@ OfflineModelConfig, OfflineNemoEncDecCtcModelConfig, OfflineParaformerModelConfig, - OfflineTdnnModelConfig, - OfflineWhisperModelConfig, - OfflineZipformerCtcModelConfig, ) from _sherpa_onnx import OfflineRecognizer as _Recognizer from _sherpa_onnx import ( OfflineRecognizerConfig, OfflineStream, + OfflineTdnnModelConfig, OfflineTransducerModelConfig, + OfflineWenetCtcModelConfig, + OfflineWhisperModelConfig, + OfflineZipformerCtcModelConfig, ) @@ -389,6 +390,70 @@ def from_tdnn_ctc( self.config = recognizer_config return self + @classmethod + def from_wenet_ctc( + cls, + model: str, + tokens: str, + num_threads: int = 1, + sample_rate: int = 16000, + feature_dim: int = 80, + decoding_method: str = "greedy_search", + debug: bool = False, + provider: str = "cpu", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + model: + Path to ``model.onnx``. + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + decoding_method: + Valid values are greedy_search. + debug: + True to show debug messages. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + wenet_ctc=OfflineWenetCtcModelConfig(model=model), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="wenet_ctc", + ) + + feat_config = OfflineFeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + ) + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + def create_stream(self, hotwords: Optional[str] = None): if hotwords is None: return self.recognizer.create_stream() diff --git a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py index c547c3166..0198ffb29 100644 --- a/sherpa-onnx/python/sherpa_onnx/online_recognizer.py +++ b/sherpa-onnx/python/sherpa_onnx/online_recognizer.py @@ -12,6 +12,7 @@ OnlineRecognizerConfig, OnlineStream, OnlineTransducerModelConfig, + OnlineWenetCtcModelConfig, ) @@ -140,13 +141,13 @@ def from_transducer( "Please use --decoding-method=modified_beam_search when using " f"--hotwords-file. Currently given: {decoding_method}" ) - + if lm and decoding_method != "modified_beam_search": raise ValueError( "Please use --decoding-method=modified_beam_search when using " f"--lm. Currently given: {decoding_method}" ) - + lm_config = OnlineLMConfig( model=lm, scale=lm_scale, @@ -271,6 +272,112 @@ def from_paraformer( self.config = recognizer_config return self + @classmethod + def from_wenet_ctc( + cls, + tokens: str, + model: str, + chunk_size: int = 16, + num_left_chunks: int = 4, + num_threads: int = 2, + sample_rate: float = 16000, + feature_dim: int = 80, + enable_endpoint_detection: bool = False, + rule1_min_trailing_silence: float = 2.4, + rule2_min_trailing_silence: float = 1.2, + rule3_min_utterance_length: float = 20.0, + decoding_method: str = "greedy_search", + provider: str = "cpu", + ): + """ + Please refer to + ``_ + to download pre-trained models for different languages, e.g., Chinese, + English, etc. + + Args: + tokens: + Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two + columns:: + + symbol integer_id + + model: + Path to ``model.onnx``. + chunk_size: + The --chunk-size parameter from WeNet. + num_left_chunks: + The --num-left-chunks parameter from WeNet. + num_threads: + Number of threads for neural network computation. + sample_rate: + Sample rate of the training data used to train the model. + feature_dim: + Dimension of the feature used to train the model. + enable_endpoint_detection: + True to enable endpoint detection. False to disable endpoint + detection. + rule1_min_trailing_silence: + Used only when enable_endpoint_detection is True. If the duration + of trailing silence in seconds is larger than this value, we assume + an endpoint is detected. + rule2_min_trailing_silence: + Used only when enable_endpoint_detection is True. If we have decoded + something that is nonsilence and if the duration of trailing silence + in seconds is larger than this value, we assume an endpoint is + detected. + rule3_min_utterance_length: + Used only when enable_endpoint_detection is True. If the utterance + length in seconds is larger than this value, we assume an endpoint + is detected. + decoding_method: + The only valid value is greedy_search. + provider: + onnxruntime execution providers. Valid values are: cpu, cuda, coreml. + """ + self = cls.__new__(cls) + _assert_file_exists(tokens) + _assert_file_exists(model) + + assert num_threads > 0, num_threads + + wenet_ctc_config = OnlineWenetCtcModelConfig( + model=model, + chunk_size=chunk_size, + num_left_chunks=num_left_chunks, + ) + + model_config = OnlineModelConfig( + wenet_ctc=wenet_ctc_config, + tokens=tokens, + num_threads=num_threads, + provider=provider, + model_type="wenet_ctc", + ) + + feat_config = FeatureExtractorConfig( + sampling_rate=sample_rate, + feature_dim=feature_dim, + ) + + endpoint_config = EndpointConfig( + rule1_min_trailing_silence=rule1_min_trailing_silence, + rule2_min_trailing_silence=rule2_min_trailing_silence, + rule3_min_utterance_length=rule3_min_utterance_length, + ) + + recognizer_config = OnlineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + endpoint_config=endpoint_config, + enable_endpoint=enable_endpoint_detection, + decoding_method=decoding_method, + ) + + self.recognizer = _Recognizer(recognizer_config) + self.config = recognizer_config + return self + def create_stream(self, hotwords: Optional[str] = None): if hotwords is None: return self.recognizer.create_stream() diff --git a/sherpa-onnx/python/tests/test_offline_recognizer.py b/sherpa-onnx/python/tests/test_offline_recognizer.py index f6d36a536..68cc70f09 100755 --- a/sherpa-onnx/python/tests/test_offline_recognizer.py +++ b/sherpa-onnx/python/tests/test_offline_recognizer.py @@ -267,6 +267,53 @@ def test_nemo_ctc_multiple_files(self): print(s1.result.text) print(s2.result.text) + def test_wenet_ctc(self): + models = [ + "sherpa-onnx-zh-wenet-aishell", + "sherpa-onnx-zh-wenet-aishell2", + "sherpa-onnx-zh-wenet-wenetspeech", + "sherpa-onnx-zh-wenet-multi-cn", + "sherpa-onnx-en-wenet-librispeech", + "sherpa-onnx-en-wenet-gigaspeech", + ] + for m in models: + for use_int8 in [True, False]: + name = "model.int8.onnx" if use_int8 else "model.onnx" + model = f"{d}/{m}/{name}" + tokens = f"{d}/{m}/tokens.txt" + + wave0 = f"{d}/{m}/test_wavs/0.wav" + wave1 = f"{d}/{m}/test_wavs/1.wav" + wave2 = f"{d}/{m}/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_wenet_ctc()") + return + + recognizer = sherpa_onnx.OfflineRecognizer.from_wenet_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + s0 = recognizer.create_stream() + samples0, sample_rate0 = read_wave(wave0) + s0.accept_waveform(sample_rate0, samples0) + + s1 = recognizer.create_stream() + samples1, sample_rate1 = read_wave(wave1) + s1.accept_waveform(sample_rate1, samples1) + + s2 = recognizer.create_stream() + samples2, sample_rate2 = read_wave(wave2) + s2.accept_waveform(sample_rate2, samples2) + + recognizer.decode_streams([s0, s1, s2]) + print(s0.result.text) + print(s1.result.text) + print(s2.result.text) + if __name__ == "__main__": unittest.main() diff --git a/sherpa-onnx/python/tests/test_online_recognizer.py b/sherpa-onnx/python/tests/test_online_recognizer.py index f5c15e5c2..7df00fe09 100755 --- a/sherpa-onnx/python/tests/test_online_recognizer.py +++ b/sherpa-onnx/python/tests/test_online_recognizer.py @@ -143,6 +143,64 @@ def test_transducer_multiple_files(self): print(f"{wave_filename}\n{result}") print("-" * 10) + def test_wenet_ctc(self): + models = [ + "sherpa-onnx-zh-wenet-aishell", + "sherpa-onnx-zh-wenet-aishell2", + "sherpa-onnx-zh-wenet-wenetspeech", + "sherpa-onnx-zh-wenet-multi-cn", + "sherpa-onnx-en-wenet-librispeech", + "sherpa-onnx-en-wenet-gigaspeech", + ] + for m in models: + for use_int8 in [True, False]: + name = ( + "model-streaming.int8.onnx" if use_int8 else "model-streaming.onnx" + ) + model = f"{d}/{m}/{name}" + tokens = f"{d}/{m}/tokens.txt" + + wave0 = f"{d}/{m}/test_wavs/0.wav" + wave1 = f"{d}/{m}/test_wavs/1.wav" + wave2 = f"{d}/{m}/test_wavs/8k.wav" + + if not Path(model).is_file(): + print("skipping test_wenet_ctc()") + return + + recognizer = sherpa_onnx.OnlineRecognizer.from_wenet_ctc( + model=model, + tokens=tokens, + num_threads=1, + provider="cpu", + ) + + streams = [] + waves = [wave0, wave1, wave2] + for wave in waves: + s = recognizer.create_stream() + samples, sample_rate = read_wave(wave) + s.accept_waveform(sample_rate, samples) + + tail_paddings = np.zeros(int(0.2 * sample_rate), dtype=np.float32) + s.accept_waveform(sample_rate, tail_paddings) + s.input_finished() + streams.append(s) + + while True: + ready_list = [] + for s in streams: + if recognizer.is_ready(s): + ready_list.append(s) + if len(ready_list) == 0: + break + recognizer.decode_streams(ready_list) + + results = [recognizer.get_result(s) for s in streams] + for wave_filename, result in zip(waves, results): + print(f"{wave_filename}\n{result}") + print("-" * 10) + if __name__ == "__main__": unittest.main()