From 45b9d4ab37359432eea3436d12e6d2155fb7c330 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Mon, 7 Aug 2023 12:34:18 +0800 Subject: [PATCH] Support whisper models (#238) --- .github/workflows/export-whisper-to-onnx.yaml | 63 +++ .github/workflows/run-java-test.yaml | 4 +- CMakeLists.txt | 2 +- cmake/kaldi-native-fbank.cmake | 16 +- python-api-examples/offline-decode-files.py | 58 +++ scripts/whisper/.gitignore | 4 + scripts/whisper/README.md | 9 + scripts/whisper/export-onnx.py | 439 ++++++++++++++++++ scripts/whisper/requirements.txt | 1 + scripts/whisper/test.py | 241 ++++++++++ sherpa-onnx/csrc/CMakeLists.txt | 6 +- sherpa-onnx/csrc/base64-decode.cc | 67 +++ sherpa-onnx/csrc/base64-decode.h | 19 + sherpa-onnx/csrc/macros.h | 1 - sherpa-onnx/csrc/offline-model-config.cc | 8 +- sherpa-onnx/csrc/offline-model-config.h | 4 + .../offline-nemo-enc-dec-ctc-model-config.cc | 2 +- .../csrc/offline-paraformer-model-config.cc | 2 +- sherpa-onnx/csrc/offline-recognizer-impl.cc | 14 +- .../csrc/offline-recognizer-whisper-impl.h | 152 ++++++ sherpa-onnx/csrc/offline-stream.cc | 46 +- sherpa-onnx/csrc/offline-stream.h | 5 + .../csrc/offline-transducer-model-config.cc | 9 +- sherpa-onnx/csrc/offline-whisper-decoder.h | 38 ++ .../offline-whisper-greedy-search-decoder.cc | 93 ++++ .../offline-whisper-greedy-search-decoder.h | 29 ++ .../csrc/offline-whisper-model-config.cc | 46 ++ .../csrc/offline-whisper-model-config.h | 30 ++ sherpa-onnx/csrc/offline-whisper-model.cc | 213 +++++++++ sherpa-onnx/csrc/offline-whisper-model.h | 85 ++++ .../csrc/sherpa-onnx-microphone-offline.cc | 12 +- sherpa-onnx/csrc/sherpa-onnx-offline.cc | 17 +- sherpa-onnx/csrc/symbol-table.cc | 9 + sherpa-onnx/csrc/symbol-table.h | 3 + sherpa-onnx/python/csrc/CMakeLists.txt | 1 + .../python/csrc/offline-model-config.cc | 24 +- .../csrc/offline-whisper-model-config.cc | 24 + .../csrc/offline-whisper-model-config.h | 16 + .../python/sherpa_onnx/offline_recognizer.py | 74 ++- 39 files changed, 1835 insertions(+), 51 deletions(-) create mode 100644 .github/workflows/export-whisper-to-onnx.yaml create mode 100644 scripts/whisper/.gitignore create mode 100644 scripts/whisper/README.md create mode 100755 scripts/whisper/export-onnx.py create mode 100644 scripts/whisper/requirements.txt create mode 100755 scripts/whisper/test.py create mode 100644 sherpa-onnx/csrc/base64-decode.cc create mode 100644 sherpa-onnx/csrc/base64-decode.h create mode 100644 sherpa-onnx/csrc/offline-recognizer-whisper-impl.h create mode 100644 sherpa-onnx/csrc/offline-whisper-decoder.h create mode 100644 sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc create mode 100644 sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h create mode 100644 sherpa-onnx/csrc/offline-whisper-model-config.cc create mode 100644 sherpa-onnx/csrc/offline-whisper-model-config.h create mode 100644 sherpa-onnx/csrc/offline-whisper-model.cc create mode 100644 sherpa-onnx/csrc/offline-whisper-model.h create mode 100644 sherpa-onnx/python/csrc/offline-whisper-model-config.cc create mode 100644 sherpa-onnx/python/csrc/offline-whisper-model-config.h diff --git a/.github/workflows/export-whisper-to-onnx.yaml b/.github/workflows/export-whisper-to-onnx.yaml new file mode 100644 index 000000000..476bff521 --- /dev/null +++ b/.github/workflows/export-whisper-to-onnx.yaml @@ -0,0 +1,63 @@ +name: export-whisper-to-onnx + +on: + workflow_dispatch: + +concurrency: + group: release-whisper-${{ github.ref }} + cancel-in-progress: true + +jobs: + release-whisper-models: + if: 
github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj' + name: ${{ matrix.model }} + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [macos-latest] + model: ["tiny.en", "base.en", "small.en", "medium.en"] + + steps: + - uses: actions/checkout@v2 + + - name: Install dependencies + shell: bash + run: | + python3 -m pip install openai-whisper torch onnxruntime onnx + + - name: export ${{ matrix.model }} + shell: bash + run: | + cd scripts/whisper + python3 ./export-onnx.py --model ${{ matrix.model }} + python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./ + + ls -lh + + ls -lh ~/.cache/whisper + + - name: Publish ${{ matrix.model }} to huggingface + shell: bash + env: + HF_TOKEN: ${{ secrets.HF_TOKEN }} + run: | + cd scripts/whisper + + git config --global user.email "csukuangfj@gmail.com" + git config --global user.name "Fangjun Kuang" + + GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface + + cp *.onnx ./huggingface + cp *.ort ./huggingface + cp *tokens.txt ./huggingface + + cd huggingface + git status + ls -lh + git lfs track "*.onnx" + git lfs track "*.ort" + git add . + git commit -m "upload ${{ matrix.model }}" + git push https://csukuangfj:$HF_TOKEN@huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main diff --git a/.github/workflows/run-java-test.yaml b/.github/workflows/run-java-test.yaml index 98a850719..b70421f16 100644 --- a/.github/workflows/run-java-test.yaml +++ b/.github/workflows/run-java-test.yaml @@ -23,14 +23,14 @@ on: - 'sherpa-onnx/jni/*' concurrency: - group: jni-${{ github.ref }} + group: run-java-test-${{ github.ref }} cancel-in-progress: true permissions: contents: read jobs: - jni: + run_java_test: runs-on: ${{ matrix.os }} strategy: fail-fast: false diff --git a/CMakeLists.txt b/CMakeLists.txt index 69f29e2a0..ef9ce213d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ cmake_minimum_required(VERSION 3.13 FATAL_ERROR) project(sherpa-onnx) -set(SHERPA_ONNX_VERSION "1.5.5") +set(SHERPA_ONNX_VERSION "1.6.0") # Disable warning about # diff --git a/cmake/kaldi-native-fbank.cmake b/cmake/kaldi-native-fbank.cmake index 91c554666..d561ce882 100644 --- a/cmake/kaldi-native-fbank.cmake +++ b/cmake/kaldi-native-fbank.cmake @@ -1,9 +1,9 @@ function(download_kaldi_native_fbank) include(FetchContent) - set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.17.tar.gz") - set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.17.tar.gz") - set(kaldi_native_fbank_HASH "SHA256=300dc282d51d738e70f194ef13a50bf4cf8d54a3b2686d75f7fc2fb821f8c1e6") + set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz") + set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz") + set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7") set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE) set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE) @@ -12,11 +12,11 @@ function(download_kaldi_native_fbank) # If you don't have access to the Internet, # please pre-download kaldi-native-fbank set(possible_file_locations - $ENV{HOME}/Downloads/kaldi-native-fbank-1.17.tar.gz - ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.17.tar.gz - 
${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.17.tar.gz - /tmp/kaldi-native-fbank-1.17.tar.gz - /star-fj/fangjun/download/github/kaldi-native-fbank-1.17.tar.gz + $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz + ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz + ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz + /tmp/kaldi-native-fbank-1.18.1.tar.gz + /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz ) foreach(f IN LISTS possible_file_locations) diff --git a/python-api-examples/offline-decode-files.py b/python-api-examples/offline-decode-files.py index 98ead3f9c..c6b63ee0d 100755 --- a/python-api-examples/offline-decode-files.py +++ b/python-api-examples/offline-decode-files.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # # Copyright (c) 2023 by manyeyes +# Copyright (c) 2023 Xiaomi Corporation """ This file demonstrates how to use sherpa-onnx Python API to transcribe @@ -34,6 +35,27 @@ (3) For CTC models from NeMo +python3 ./python-api-examples/offline-decode-files.py \ + --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \ + --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \ + --num-threads=2 \ + --decoding-method=greedy_search \ + --debug=false \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \ + ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav + +(4) For Whisper models + +python3 ./python-api-examples/offline-decode-files.py \ + --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \ + --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \ + --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \ + --num-threads=1 \ + ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \ + ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav + Please refer to https://k2-fsa.github.io/sherpa/onnx/index.html to install sherpa-onnx and to download the pre-trained models @@ -144,6 +166,20 @@ def get_args(): help="Number of threads for neural network computation", ) + parser.add_argument( + "--whisper-encoder", + default="", + type=str, + help="Path to whisper encoder model", + ) + + parser.add_argument( + "--whisper-decoder", + default="", + type=str, + help="Path to whisper decoder model", + ) + parser.add_argument( "--decoding-method", type=str, @@ -247,6 +283,8 @@ def main(): if args.encoder: assert len(args.paraformer) == 0, args.paraformer assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()] if contexts: @@ -271,6 +309,9 @@ def main(): ) elif args.paraformer: assert len(args.nemo_ctc) == 0, args.nemo_ctc + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + assert_file_exists(args.paraformer) recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer( @@ -283,6 +324,11 @@ def main(): debug=args.debug, ) elif args.nemo_ctc: + assert len(args.whisper_encoder) == 0, args.whisper_encoder + assert len(args.whisper_decoder) == 0, args.whisper_decoder + + assert_file_exists(args.nemo_ctc) + recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc( model=args.nemo_ctc, tokens=args.tokens, @@ -292,6 +338,18 @@ def main(): decoding_method=args.decoding_method, debug=args.debug, ) + elif args.whisper_encoder: + 
assert_file_exists(args.whisper_encoder) + assert_file_exists(args.whisper_decoder) + + recognizer = sherpa_onnx.OfflineRecognizer.from_whisper( + encoder=args.whisper_encoder, + decoder=args.whisper_decoder, + tokens=args.tokens, + num_threads=args.num_threads, + decoding_method=args.decoding_method, + debug=args.debug, + ) else: print("Please specify at least one model") return diff --git a/scripts/whisper/.gitignore b/scripts/whisper/.gitignore new file mode 100644 index 000000000..fbe9a87e7 --- /dev/null +++ b/scripts/whisper/.gitignore @@ -0,0 +1,4 @@ +*.onnx +*.config +*.ort +*-tokens.txt diff --git a/scripts/whisper/README.md b/scripts/whisper/README.md new file mode 100644 index 000000000..eda441486 --- /dev/null +++ b/scripts/whisper/README.md @@ -0,0 +1,9 @@ +# Introduction + +This folder contains code showing how to convert [Whisper][whisper] to onnx +and use onnxruntime to replace PyTorch for speech recognition. + +You can use [sherpa-onnx][sherpa-onnx] to run the converted model. + +[whisper]: https://github.com/openai/whisper +[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py new file mode 100755 index 000000000..1cdbaaf0e --- /dev/null +++ b/scripts/whisper/export-onnx.py @@ -0,0 +1,439 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +# flake8: noqa + +""" +Note: Code in this file is modified from +https://github.com/TadaoYamaoka/whisper/blob/main/to_onnx.py + +Thanks to https://github.com/TadaoYamaoka +for making the onnx export script public. +""" + +import argparse +from pathlib import Path +from typing import Any, Dict, Optional + +import onnx +import torch +from onnxruntime.quantization import QuantType, quantize_dynamic +from torch import Tensor, nn + +import whisper +from whisper.model import ( + AudioEncoder, + MultiHeadAttention, + ResidualAttentionBlock, + TextDecoder, +) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + required=True, + # fmt: off + choices=[ + "tiny", "tiny.en", "base", "base.en", + "small", "small.en", "medium", "medium.en", + "large", "large-v1", "large-v2"], + # fmt: on + ) + return parser.parse_args() + + +def add_meta_data(filename: str, meta_data: Dict[str, Any]): + """Add meta data to an ONNX model. It is changed in-place. + + Args: + filename: + Filename of the ONNX model to be changed. + meta_data: + Key-value pairs. 
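+        Values are converted to strings with str() before being saved.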
+ """ + model = onnx.load(filename) + for key, value in meta_data.items(): + meta = model.metadata_props.add() + meta.key = key + meta.value = str(value) + + onnx.save(model, filename) + + +class AudioEncoderTensorCache(nn.Module): + def __init__(self, inAudioEncoder: AudioEncoder, inTextDecoder: TextDecoder): + super().__init__() + self.audioEncoder = inAudioEncoder + self.textDecoder = inTextDecoder + + def forward(self, x: Tensor): + audio_features = self.audioEncoder(x) + + n_layer_cross_k_list = [] + n_layer_cross_v_list = [] + for block in self.textDecoder.blocks: + n_layer_cross_k_list.append(block.cross_attn.key(audio_features)) + n_layer_cross_v_list.append(block.cross_attn.value(audio_features)) + + return torch.stack(n_layer_cross_k_list), torch.stack(n_layer_cross_v_list) + + +class MultiHeadAttentionCross(nn.Module): + def __init__(self, inMultiHeadAttention: MultiHeadAttention): + super().__init__() + self.multiHeadAttention = inMultiHeadAttention + + def forward( + self, + x: Tensor, + k: Tensor, + v: Tensor, + mask: Optional[Tensor] = None, + ): + q = self.multiHeadAttention.query(x) + wv, qk = self.multiHeadAttention.qkv_attention(q, k, v, mask) + return self.multiHeadAttention.out(wv) + + +class MultiHeadAttentionSelf(nn.Module): + def __init__(self, inMultiHeadAttention: MultiHeadAttention): + super().__init__() + self.multiHeadAttention = inMultiHeadAttention + + def forward( + self, + x: Tensor, # (b, n_ctx , n_state) + k_cache: Tensor, # (b, n_ctx_cache, n_state) + v_cache: Tensor, # (b, n_ctx_cache, n_state) + mask: Tensor, + ): + q = self.multiHeadAttention.query(x) # (b, n_ctx, n_state) + k = self.multiHeadAttention.key(x) # (b, n_ctx, n_state) + v = self.multiHeadAttention.value(x) # (b, n_ctx, n_state) + + k_cache[:, -k.shape[1] :, :] = k # (b, n_ctx_cache + n_ctx, n_state) + v_cache[:, -v.shape[1] :, :] = v # (b, n_ctx_cache + n_ctx, n_state) + + wv, qk = self.multiHeadAttention.qkv_attention(q, k_cache, v_cache, mask) + return self.multiHeadAttention.out(wv), k_cache, v_cache + + +class ResidualAttentionBlockTensorCache(nn.Module): + def __init__(self, inResidualAttentionBlock: ResidualAttentionBlock): + super().__init__() + self.originalBlock = inResidualAttentionBlock + self.attn = MultiHeadAttentionSelf(inResidualAttentionBlock.attn) + self.cross_attn = ( + MultiHeadAttentionCross(inResidualAttentionBlock.cross_attn) + if inResidualAttentionBlock.cross_attn + else None + ) + + def forward( + self, + x: Tensor, + self_k_cache: Tensor, + self_v_cache: Tensor, + cross_k: Tensor, + cross_v: Tensor, + mask: Tensor, + ): + self_attn_x, self_k_cache_updated, self_v_cache_updated = self.attn( + self.originalBlock.attn_ln(x), self_k_cache, self_v_cache, mask=mask + ) + x = x + self_attn_x + + if self.cross_attn: + x = x + self.cross_attn( + self.originalBlock.cross_attn_ln(x), cross_k, cross_v + ) + + x = x + self.originalBlock.mlp(self.originalBlock.mlp_ln(x)) + return x, self_k_cache_updated, self_v_cache_updated + + +class TextDecoderTensorCache(nn.Module): + def __init__(self, inTextDecoder: TextDecoder, in_n_ctx: int): + super().__init__() + self.textDecoder = inTextDecoder + self.n_ctx = in_n_ctx + + self.blocks = [] + for orginal_block in self.textDecoder.blocks: + self.blocks.append(ResidualAttentionBlockTensorCache(orginal_block)) + + def forward( + self, + tokens: Tensor, + n_layer_self_k_cache: Tensor, + n_layer_self_v_cache: Tensor, + n_layer_cross_k: Tensor, + n_layer_cross_v: Tensor, + offset: Tensor, + ): + x = ( + 
self.textDecoder.token_embedding(tokens) + + self.textDecoder.positional_embedding[ + offset[0] : offset[0] + tokens.shape[-1] + ] + ) + x = x.to(n_layer_cross_k[0].dtype) + + i = 0 + for block in self.blocks: + self_k_cache = n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] + self_v_cache = n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] + x, self_k_cache, self_v_cache = block( + x, + self_k_cache=self_k_cache, + self_v_cache=self_v_cache, + cross_k=n_layer_cross_k[i], + cross_v=n_layer_cross_v[i], + mask=self.textDecoder.mask, + ) + n_layer_self_k_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_k_cache + n_layer_self_v_cache[i, :, : offset[0] + tokens.shape[-1], :] = self_v_cache + i += 1 + + x = self.textDecoder.ln(x) + + logits = ( + x + @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1) + ).float() + + return logits, n_layer_self_k_cache, n_layer_self_v_cache + + +# ref: https://github.com/ggerganov/whisper.cpp/blob/master/models/convert-pt-to-ggml.py#L232 +def convert_tokens(name, model): + whisper_dir = Path(whisper.__file__).parent + multilingual = model.is_multilingual + tokenizer = ( + whisper_dir + / "assets" + / (multilingual and "multilingual.tiktoken" or "gpt2.tiktoken") + ) + if not tokenizer.is_file(): + raise ValueError(f"Cannot find {tokenizer}") + + # import base64 + + with open(tokenizer, "r") as f: + contents = f.read() + # tokens = { + # base64.b64decode(token): int(rank) + # for token, rank in (line.split() for line in contents.splitlines() if line) + # } + tokens = { + token: int(rank) + for token, rank in (line.split() for line in contents.splitlines() if line) + } + + with open(f"{name}-tokens.txt", "w") as f: + for t, i in tokens.items(): + f.write(f"{t} {i}\n") + + +@torch.no_grad() +def main(): + args = get_args() + name = args.model + + opset_version = 13 + + model = whisper.load_model(name) + convert_tokens(name=name, model=model) + + # write tokens + + tokenizer = whisper.tokenizer.get_tokenizer(model.is_multilingual) + model.eval() + print(model.dims) + audio = torch.rand(16000 * 2) + audio = whisper.pad_or_trim(audio) + assert audio.shape == (16000 * 30,), audio.shape + + # make log-Mel spectrogram and move to the same device as the model + mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0) + batch_size = 1 + assert mel.shape == (batch_size, 80, 30 * 100) + + encoder = AudioEncoderTensorCache(model.encoder, model.decoder) + n_layer_cross_k, n_layer_cross_v = encoder(mel) + assert n_layer_cross_k.shape == ( + model.dims.n_text_layer, + batch_size, + model.dims.n_audio_ctx, + model.dims.n_text_state, + ), n_layer_cross_k.shape + assert n_layer_cross_v.shape == ( + model.dims.n_text_layer, + batch_size, + model.dims.n_audio_ctx, + model.dims.n_text_state, + ), n_layer_cross_v.shape + + encoder_filename = f"{name}-encoder.onnx" + torch.onnx.export( + encoder, + mel, + encoder_filename, + opset_version=opset_version, + input_names=["mel"], + output_names=["n_layer_cross_k", "n_layer_cross_v"], + dynamic_axes={ + "mel": {0: "n_audio"}, # n_audio is also known as batch_size + "n_layer_cross_k": {1: "n_audio"}, + "n_layer_cross_v": {1: "n_audio"}, + }, + ) + + encoder_meta_data = { + "model_type": f"whisper-{name}", + "version": "1", + "maintainer": "k2-fsa", + "n_mels": model.dims.n_mels, + "n_audio_ctx": model.dims.n_audio_ctx, + "n_audio_state": model.dims.n_audio_state, + "n_audio_head": model.dims.n_audio_head, + "n_audio_layer": model.dims.n_audio_layer, + "n_vocab": 
model.dims.n_vocab, + "n_text_ctx": model.dims.n_text_ctx, + "n_text_state": model.dims.n_text_state, + "n_text_head": model.dims.n_text_head, + "n_text_layer": model.dims.n_text_layer, + "sot_sequence": ",".join(list(map(str, tokenizer.sot_sequence))), + "all_language_tokens": ",".join(list(map(str, tokenizer.all_language_tokens))), + "all_language_codes": ",".join(tokenizer.all_language_codes), + "sot": tokenizer.sot, + "sot_index": tokenizer.sot_sequence.index(tokenizer.sot), + "eot": tokenizer.eot, + "blank_id": tokenizer.encode(" ")[0], + "is_multilingual": int(model.is_multilingual), + "no_speech": tokenizer.no_speech, + "non_speech_tokens": ",".join(list(map(str, tokenizer.non_speech_tokens))), + "transcribe": tokenizer.transcribe, + "translate": tokenizer.translate, + "sot_prev": tokenizer.sot_prev, + "sot_lm": tokenizer.sot_lm, + "no_timestamps": tokenizer.no_timestamps, + } + print(f"encoder_meta_data: {encoder_meta_data}") + add_meta_data(filename=encoder_filename, meta_data=encoder_meta_data) + + n_audio = mel.shape[0] + tokens = torch.tensor([[tokenizer.sot, tokenizer.sot, tokenizer.sot]] * n_audio).to( + mel.device + ) # [n_audio, 3] + decoder = TextDecoderTensorCache(model.decoder, model.dims.n_text_ctx) + n_layer_self_k_cache = torch.zeros( + ( + len(model.decoder.blocks), + n_audio, + model.dims.n_text_ctx, + model.dims.n_text_state, + ), + device=mel.device, + ) + n_layer_self_v_cache = torch.zeros( + ( + len(model.decoder.blocks), + n_audio, + model.dims.n_text_ctx, + model.dims.n_text_state, + ), + device=mel.device, + ) + offset = torch.zeros(1, dtype=torch.int64).to(mel.device) + logits, n_layer_self_k_cache, n_layer_self_v_cache = decoder( + tokens, + n_layer_self_k_cache, + n_layer_self_v_cache, + n_layer_cross_k, + n_layer_cross_v, + offset, + ) + assert logits.shape == (n_audio, tokens.shape[1], model.dims.n_vocab) + assert n_layer_self_k_cache.shape == ( + model.dims.n_text_layer, + n_audio, + model.dims.n_text_ctx, + model.dims.n_text_state, + ) + assert n_layer_self_v_cache.shape == ( + model.dims.n_text_layer, + n_audio, + model.dims.n_text_ctx, + model.dims.n_text_state, + ) + + offset = torch.tensor([tokens.shape[1]], dtype=torch.int64).to(mel.device) + tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device) # [n_audio, 1] + + logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = decoder( + tokens, + n_layer_self_k_cache, + n_layer_self_v_cache, + n_layer_cross_k, + n_layer_cross_v, + offset, + ) + + decoder_filename = f"{name}-decoder.onnx" + torch.onnx.export( + decoder, + ( + tokens, + n_layer_self_k_cache, + n_layer_self_v_cache, + n_layer_cross_k, + n_layer_cross_v, + offset, + ), + decoder_filename, + opset_version=opset_version, + input_names=[ + "tokens", + "in_n_layer_self_k_cache", + "in_n_layer_self_v_cache", + "n_layer_cross_k", + "n_layer_cross_v", + "offset", + ], + output_names=["logits", "out_n_layer_self_k_cache", "out_n_layer_self_v_cache"], + dynamic_axes={ + "tokens": {0: "n_audio", 1: "n_tokens"}, + "in_n_layer_self_k_cache": {1: "n_audio"}, + "in_n_layer_self_v_cache": {1: "n_audio"}, + "n_layer_cross_k": {1: "n_audio"}, + "n_layer_cross_v": {1: "n_audio"}, + }, + ) + + # Generate int8 quantization models + # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection + + print("Generate int8 quantization models") + + encoder_filename_int8 = f"{name}-encoder.int8.onnx" + quantize_dynamic( + model_input=encoder_filename, + model_output=encoder_filename_int8, + 
op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + decoder_filename_int8 = f"{name}-decoder.int8.onnx" + quantize_dynamic( + model_input=decoder_filename, + model_output=decoder_filename_int8, + op_types_to_quantize=["MatMul"], + weight_type=QuantType.QInt8, + ) + + +if __name__ == "__main__": + main() diff --git a/scripts/whisper/requirements.txt b/scripts/whisper/requirements.txt new file mode 100644 index 000000000..73bca28e3 --- /dev/null +++ b/scripts/whisper/requirements.txt @@ -0,0 +1 @@ +openai-whisper diff --git a/scripts/whisper/test.py b/scripts/whisper/test.py new file mode 100755 index 000000000..d1c422bff --- /dev/null +++ b/scripts/whisper/test.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) +""" +Please first run ./export-onnx.py +before you run this script +""" +import base64 +from typing import Tuple + +import kaldi_native_fbank as knf +import onnxruntime as ort +import torch + +import whisper +import argparse + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + required=True, + # fmt: off + choices=[ + "tiny", "tiny.en", "base", "base.en", + "small", "small.en", "medium", "medium.en", + "large", "large-v1", "large-v2"], + # fmt: on + ) + return parser.parse_args() + + +class OnnxModel: + def __init__( + self, + encoder: str, + decoder: str, + ): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 1 + session_opts.intra_op_num_threads = 4 + + self.session_opts = session_opts + + self.init_encoder(encoder) + self.init_decoder(decoder) + + def init_encoder(self, encoder: str): + self.encoder = ort.InferenceSession( + encoder, + sess_options=self.session_opts, + ) + + meta = self.encoder.get_modelmeta().custom_metadata_map + self.n_text_layer = int(meta["n_text_layer"]) + self.n_text_ctx = int(meta["n_text_ctx"]) + self.n_text_state = int(meta["n_text_state"]) + self.sot = int(meta["sot"]) + self.eot = int(meta["eot"]) + self.translate = int(meta["translate"]) + self.no_timestamps = int(meta["no_timestamps"]) + self.no_speech = int(meta["no_speech"]) + self.blank = int(meta["blank_id"]) + + self.sot_sequence = list(map(int, meta["sot_sequence"].split(","))) + + self.is_multilingual = int(meta["is_multilingual"]) == 1 + + def init_decoder(self, decoder: str): + self.decoder = ort.InferenceSession( + decoder, + sess_options=self.session_opts, + ) + + def run_encoder( + self, + mel: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + n_layer_cross_k, n_layer_cross_v = self.encoder.run( + [ + self.encoder.get_outputs()[0].name, + self.encoder.get_outputs()[1].name, + ], + { + self.encoder.get_inputs()[0].name: mel.numpy(), + }, + ) + return torch.from_numpy(n_layer_cross_k), torch.from_numpy(n_layer_cross_v) + + def run_decoder( + self, + tokens: torch.Tensor, + n_layer_self_k_cache: torch.Tensor, + n_layer_self_v_cache: torch.Tensor, + n_layer_cross_k: torch.Tensor, + n_layer_cross_v: torch.Tensor, + offset: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache = self.decoder.run( + [ + self.decoder.get_outputs()[0].name, + self.decoder.get_outputs()[1].name, + self.decoder.get_outputs()[2].name, + ], + { + self.decoder.get_inputs()[0].name: tokens.numpy(), + self.decoder.get_inputs()[1].name: n_layer_self_k_cache.numpy(), + self.decoder.get_inputs()[2].name: n_layer_self_v_cache.numpy(), + self.decoder.get_inputs()[3].name: 
n_layer_cross_k.numpy(), + self.decoder.get_inputs()[4].name: n_layer_cross_v.numpy(), + self.decoder.get_inputs()[5].name: offset.numpy(), + }, + ) + return ( + torch.from_numpy(logits), + torch.from_numpy(out_n_layer_self_k_cache), + torch.from_numpy(out_n_layer_self_v_cache), + ) + + def get_self_cache(self) -> Tuple[torch.Tensor, torch.Tensor]: + batch_size = 1 + n_layer_self_k_cache = torch.zeros( + self.n_text_layer, + batch_size, + self.n_text_ctx, + self.n_text_state, + ) + n_layer_self_v_cache = torch.zeros( + self.n_text_layer, + batch_size, + self.n_text_ctx, + self.n_text_state, + ) + return n_layer_self_k_cache, n_layer_self_v_cache + + def suppress_tokens(self, logits, is_initial: bool) -> None: + # suppress blank + if is_initial: + logits[self.eot] = float("-inf") + logits[self.blank] = float("-inf") + + # suppress <|notimestamps|> + logits[self.no_timestamps] = float("-inf") + + logits[self.sot] = float("-inf") + logits[self.no_speech] = float("-inf") + + # logits is changed in-place + logits[self.translate] = float("-inf") + + +def load_tokens(filename): + tokens = dict() + with open(filename, "r") as f: + for line in f: + t, i = line.split() + tokens[int(i)] = t + return tokens + + +def main(): + args = get_args() + name = args.model + + encoder = f"./{name}-encoder.onnx" + decoder = f"./{name}-decoder.onnx" + audio = whisper.load_audio("0.wav") + + features = [] + online_whisper_fbank = knf.OnlineWhisperFbank(knf.FrameExtractionOptions()) + online_whisper_fbank.accept_waveform(16000, audio) + online_whisper_fbank.input_finished() + for i in range(online_whisper_fbank.num_frames_ready): + f = online_whisper_fbank.get_frame(i) + f = torch.from_numpy(f) + features.append(f) + + features = torch.stack(features) + + log_spec = torch.clamp(features, min=1e-10).log10() + log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + mel = (log_spec + 4.0) / 4.0 + target = 3000 + mel = torch.nn.functional.pad(mel, (0, 0, 0, target - mel.shape[0]), "constant", 0) + mel = mel.t().unsqueeze(0) + + model = OnnxModel(encoder, decoder) + n_layer_cross_k, n_layer_cross_v = model.run_encoder(mel) + n_layer_self_k_cache, n_layer_self_v_cache = model.get_self_cache() + + tokens = torch.tensor([model.sot_sequence], dtype=torch.int64) + offset = torch.zeros(1, dtype=torch.int64) + logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder( + tokens=tokens, + n_layer_self_k_cache=n_layer_self_k_cache, + n_layer_self_v_cache=n_layer_self_v_cache, + n_layer_cross_k=n_layer_cross_k, + n_layer_cross_v=n_layer_cross_v, + offset=offset, + ) + # logits.shape (batch_size, tokens.shape[1], vocab_size) + logits = logits[0, -1] + model.suppress_tokens(logits, is_initial=True) + # logits = logits.softmax(dim=-1) + # for greedy search, we don't need to compute softmax or log_softmax + max_token_id = logits.argmax(dim=-1) + results = [] + for i in range(model.n_text_ctx): + if max_token_id == model.eot: + break + results.append(max_token_id.item()) + tokens = torch.tensor([[results[-1]]]) + offset += 1 + + logits, n_layer_self_k_cache, n_layer_self_v_cache = model.run_decoder( + tokens=tokens, + n_layer_self_k_cache=n_layer_self_k_cache, + n_layer_self_v_cache=n_layer_self_v_cache, + n_layer_cross_k=n_layer_cross_k, + n_layer_cross_v=n_layer_cross_v, + offset=offset, + ) + logits = logits[0, -1] + model.suppress_tokens(logits, is_initial=False) + max_token_id = logits.argmax(dim=-1) + token_table = load_tokens(f"./{name}-tokens.txt") + s = b"" + for i in results: + if i in token_table: + s += 
base64.b64decode(token_table[i]) + else: + print("oov", i) + + print(s.decode().strip()) + print(results) + print(model.sot_sequence) + + +if __name__ == "__main__": + main() diff --git a/sherpa-onnx/csrc/CMakeLists.txt b/sherpa-onnx/csrc/CMakeLists.txt index 9a431ba0f..9060c0ac5 100644 --- a/sherpa-onnx/csrc/CMakeLists.txt +++ b/sherpa-onnx/csrc/CMakeLists.txt @@ -11,6 +11,7 @@ if(SHERPA_ONNX_ENABLE_PYTHON) endif() set(sources + base64-decode.cc cat.cc context-graph.cc endpoint.cc @@ -35,6 +36,9 @@ set(sources offline-transducer-model-config.cc offline-transducer-model.cc offline-transducer-modified-beam-search-decoder.cc + offline-whisper-greedy-search-decoder.cc + offline-whisper-model-config.cc + offline-whisper-model.cc online-conformer-transducer-model.cc online-lm-config.cc online-lm.cc @@ -50,12 +54,12 @@ set(sources online-zipformer-transducer-model.cc online-zipformer2-transducer-model.cc onnx-utils.cc - session.cc packed-sequence.cc pad-sequence.cc parse-options.cc provider.cc resample.cc + session.cc slice.cc stack.cc symbol-table.cc diff --git a/sherpa-onnx/csrc/base64-decode.cc b/sherpa-onnx/csrc/base64-decode.cc new file mode 100644 index 000000000..b22e443e1 --- /dev/null +++ b/sherpa-onnx/csrc/base64-decode.cc @@ -0,0 +1,67 @@ +// sherpa-onnx/csrc/base64-decode.cc +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/base64-decode.h" + +#include "sherpa-onnx/csrc/macros.h" + +namespace sherpa_onnx { + +static int32_t Ord(char c) { + if (c >= 'A' && c <= 'Z') { + return c - 'A'; + } else if (c >= 'a' && c <= 'z') { + return c - 'a' + ('Z' - 'A') + 1; + } else if (c >= '0' && c <= '9') { + return c - '0' + ('Z' - 'A') + ('z' - 'a') + 2; + } else if (c == '+') { + return 62; + } else if (c == '/') { + return 63; + } + + SHERPA_ONNX_LOGE("Unknown character %d, %c\n", c, c); + + exit(-1); +} + +// see +// https://github.com/ReneNyffenegger/cpp-base64/blob/master/base64.cpp#L243 +std::string Base64Decode(const std::string &s) { + if (s.empty()) { + SHERPA_ONNX_LOGE("Empty string!"); + exit(-1); + } + + int32_t n = s.size() / 4 * 3; + + std::string ans; + ans.reserve(n); + + int32_t i = 0; + while (i < static_cast(s.size())) { + if (s[i] == '=') { + return " "; + } + + int32_t first = (Ord(s[i]) << 2) + ((Ord(s[i + 1]) & 0x30) >> 4); + ans.push_back(first); + + if (i + 2 < static_cast(s.size()) && s[i + 2] != '=') { + int32_t second = + ((Ord(s[i + 1]) & 0x0f) << 4) + ((Ord(s[i + 2]) & 0x3c) >> 2); + ans.push_back(second); + + if (i + 3 < static_cast(s.size()) && s[i + 3] != '=') { + int32_t third = ((Ord(s[i + 2]) & 0x03) << 6) + Ord(s[i + 3]); + ans.push_back(third); + } + } + i += 4; + } + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/base64-decode.h b/sherpa-onnx/csrc/base64-decode.h new file mode 100644 index 000000000..3b2f9a34d --- /dev/null +++ b/sherpa-onnx/csrc/base64-decode.h @@ -0,0 +1,19 @@ +// sherpa-onnx/csrc/base64-decode.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_BASE64_DECODE_H_ +#define SHERPA_ONNX_CSRC_BASE64_DECODE_H_ + +#include + +namespace sherpa_onnx { + +/** @param s A base64 encoded string. + * @return Return the decoded string. 
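+ *
+ *  It is used to decode the base64-encoded symbols in tokens.txt of
+ *  whisper models.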
+ */ +std::string Base64Decode(const std::string &s); + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_BASE64_DECODE_H_ diff --git a/sherpa-onnx/csrc/macros.h b/sherpa-onnx/csrc/macros.h index 685d34ac0..93e0accdb 100644 --- a/sherpa-onnx/csrc/macros.h +++ b/sherpa-onnx/csrc/macros.h @@ -1,4 +1,3 @@ - // sherpa-onnx/csrc/macros.h // // Copyright 2023 Xiaomi Corporation diff --git a/sherpa-onnx/csrc/offline-model-config.cc b/sherpa-onnx/csrc/offline-model-config.cc index 92380e76f..9808d8f6d 100644 --- a/sherpa-onnx/csrc/offline-model-config.cc +++ b/sherpa-onnx/csrc/offline-model-config.cc @@ -14,6 +14,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { transducer.Register(po); paraformer.Register(po); nemo_ctc.Register(po); + whisper.Register(po); po->Register("tokens", &tokens, "Path to tokens.txt"); @@ -28,7 +29,7 @@ void OfflineModelConfig::Register(ParseOptions *po) { po->Register("model-type", &model_type, "Specify it to reduce model initialization time. " - "Valid values are: transducer, paraformer, nemo_ctc. " + "Valid values are: transducer, paraformer, nemo_ctc, whisper." "All other values lead to loading the model twice."); } @@ -51,6 +52,10 @@ bool OfflineModelConfig::Validate() const { return nemo_ctc.Validate(); } + if (!whisper.encoder.empty()) { + return whisper.Validate(); + } + return transducer.Validate(); } @@ -61,6 +66,7 @@ std::string OfflineModelConfig::ToString() const { os << "transducer=" << transducer.ToString() << ", "; os << "paraformer=" << paraformer.ToString() << ", "; os << "nemo_ctc=" << nemo_ctc.ToString() << ", "; + os << "whisper=" << whisper.ToString() << ", "; os << "tokens=\"" << tokens << "\", "; os << "num_threads=" << num_threads << ", "; os << "debug=" << (debug ? "True" : "False") << ", "; diff --git a/sherpa-onnx/csrc/offline-model-config.h b/sherpa-onnx/csrc/offline-model-config.h index 4afdd65a8..41f441c94 100644 --- a/sherpa-onnx/csrc/offline-model-config.h +++ b/sherpa-onnx/csrc/offline-model-config.h @@ -9,6 +9,7 @@ #include "sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.h" #include "sherpa-onnx/csrc/offline-paraformer-model-config.h" #include "sherpa-onnx/csrc/offline-transducer-model-config.h" +#include "sherpa-onnx/csrc/offline-whisper-model-config.h" namespace sherpa_onnx { @@ -16,6 +17,7 @@ struct OfflineModelConfig { OfflineTransducerModelConfig transducer; OfflineParaformerModelConfig paraformer; OfflineNemoEncDecCtcModelConfig nemo_ctc; + OfflineWhisperModelConfig whisper; std::string tokens; int32_t num_threads = 2; @@ -37,11 +39,13 @@ struct OfflineModelConfig { OfflineModelConfig(const OfflineTransducerModelConfig &transducer, const OfflineParaformerModelConfig ¶former, const OfflineNemoEncDecCtcModelConfig &nemo_ctc, + const OfflineWhisperModelConfig &whisper, const std::string &tokens, int32_t num_threads, bool debug, const std::string &provider, const std::string &model_type) : transducer(transducer), paraformer(paraformer), nemo_ctc(nemo_ctc), + whisper(whisper), tokens(tokens), num_threads(num_threads), debug(debug), diff --git a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc index c28c522d1..5589402ee 100644 --- a/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc +++ b/sherpa-onnx/csrc/offline-nemo-enc-dec-ctc-model-config.cc @@ -16,7 +16,7 @@ void OfflineNemoEncDecCtcModelConfig::Register(ParseOptions *po) { bool OfflineNemoEncDecCtcModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("%s does 
not exist", model.c_str()); + SHERPA_ONNX_LOGE("NeMo model: %s does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-paraformer-model-config.cc b/sherpa-onnx/csrc/offline-paraformer-model-config.cc index dad43ff69..82886fe87 100644 --- a/sherpa-onnx/csrc/offline-paraformer-model-config.cc +++ b/sherpa-onnx/csrc/offline-paraformer-model-config.cc @@ -15,7 +15,7 @@ void OfflineParaformerModelConfig::Register(ParseOptions *po) { bool OfflineParaformerModelConfig::Validate() const { if (!FileExists(model)) { - SHERPA_ONNX_LOGE("%s does not exist", model.c_str()); + SHERPA_ONNX_LOGE("Paraformer model %s does not exist", model.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-recognizer-impl.cc b/sherpa-onnx/csrc/offline-recognizer-impl.cc index 8a2b42a08..5058a8ce2 100644 --- a/sherpa-onnx/csrc/offline-recognizer-impl.cc +++ b/sherpa-onnx/csrc/offline-recognizer-impl.cc @@ -11,6 +11,7 @@ #include "sherpa-onnx/csrc/offline-recognizer-ctc-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-paraformer-impl.h" #include "sherpa-onnx/csrc/offline-recognizer-transducer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer-whisper-impl.h" #include "sherpa-onnx/csrc/onnx-utils.h" #include "sherpa-onnx/csrc/text-utils.h" @@ -26,6 +27,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } else if (model_type == "nemo_ctc") { return std::make_unique(config); + } else if (model_type == "whisper") { + return std::make_unique(config); } else { SHERPA_ONNX_LOGE( "Invalid model_type: %s. Trying to load the model to get its type", @@ -43,6 +46,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( model_filename = config.model_config.paraformer.model; } else if (!config.model_config.nemo_ctc.model.empty()) { model_filename = config.model_config.nemo_ctc.model; + } else if (!config.model_config.whisper.encoder.empty()) { + model_filename = config.model_config.whisper.encoder; } else { SHERPA_ONNX_LOGE("Please provide a model"); exit(-1); @@ -77,6 +82,8 @@ std::unique_ptr OfflineRecognizerImpl::Create( "\n " "https://huggingface.co/csukuangfj/" "paraformer-onnxruntime-python-example/blob/main/add-model-metadata.py" + "\n " + "(3) Whisper" "\n"); exit(-1); } @@ -95,12 +102,17 @@ std::unique_ptr OfflineRecognizerImpl::Create( return std::make_unique(config); } + if (strncmp(model_type.c_str(), "whisper", 7) == 0) { + return std::make_unique(config); + } + SHERPA_ONNX_LOGE( "\nUnsupported model_type: %s\n" "We support only the following model types at present: \n" " - Non-streaming transducer models from icefall\n" " - Non-streaming Paraformer models from FunASR\n" - " - EncDecCTCModelBPE models from NeMo\n", + " - EncDecCTCModelBPE models from NeMo\n" + " - Whisper models\n", model_type.c_str()); exit(-1); diff --git a/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h new file mode 100644 index 000000000..efe9da282 --- /dev/null +++ b/sherpa-onnx/csrc/offline-recognizer-whisper-impl.h @@ -0,0 +1,152 @@ +// sherpa-onnx/csrc/offline-recognizer-whisper-impl.h +// +// Copyright (c) 2022-2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ + +#include +#include +#include +#include +#include +#include + +#include "sherpa-onnx/csrc/offline-model-config.h" +#include "sherpa-onnx/csrc/offline-recognizer-impl.h" +#include "sherpa-onnx/csrc/offline-recognizer.h" +#include 
"sherpa-onnx/csrc/offline-whisper-decoder.h" +#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h" +#include "sherpa-onnx/csrc/offline-whisper-model.h" +#include "sherpa-onnx/csrc/symbol-table.h" +#include "sherpa-onnx/csrc/transpose.h" + +namespace sherpa_onnx { + +static OfflineRecognitionResult Convert(const OfflineWhisperDecoderResult &src, + const SymbolTable &sym_table) { + OfflineRecognitionResult r; + r.tokens.reserve(src.tokens.size()); + + for (auto i : src.tokens) { + if (!sym_table.contains(i)) { + continue; + } + + const auto &s = sym_table[i]; + r.text += s; + r.tokens.push_back(s); + } + + return r; +} + +class OfflineRecognizerWhisperImpl : public OfflineRecognizerImpl { + public: + explicit OfflineRecognizerWhisperImpl(const OfflineRecognizerConfig &config) + : config_(config), + symbol_table_(config_.model_config.tokens), + model_(std::make_unique(config.model_config)) { + // tokens.txt from whisper is base64 encoded, so we need to decode it + symbol_table_.ApplyBase64Decode(); + + if (config.decoding_method == "greedy_search") { + decoder_ = + std::make_unique(model_.get()); + } else { + SHERPA_ONNX_LOGE( + "Only greedy_search is supported at present for whisper. Given %s", + config.decoding_method.c_str()); + exit(-1); + } + } + + std::unique_ptr CreateStream() const override { + return std::make_unique(WhisperTag{}); + } + + void DecodeStreams(OfflineStream **ss, int32_t n) const override { + // batch decoding is not implemented yet + for (int32_t i = 0; i != n; ++i) { + DecodeStream(ss[i]); + } + } + + private: + void DecodeStream(OfflineStream *s) const { + int32_t max_num_frames = 3000; + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + int32_t feat_dim = s->FeatureDim(); + std::vector f = s->GetFrames(); + int32_t num_frames = f.size() / feat_dim; + + if (num_frames > max_num_frames) { + SHERPA_ONNX_LOGE("Only waves less than 30 seconds are supported."); + exit(-1); + } + + NormalizeFeatures(f.data(), num_frames, feat_dim); + + std::array shape{1, max_num_frames, feat_dim}; + + Ort::Value mel = Ort::Value::CreateTensor( + model_->Allocator(), shape.data(), shape.size()); + float *p_mel = mel.GetTensorMutableData(); + std::copy(f.begin(), f.end(), p_mel); + + memset(p_mel + f.size(), 0, + (max_num_frames - num_frames) * feat_dim * sizeof(float)); + mel = Transpose12(model_->Allocator(), &mel); + + auto cross_kv = model_->ForwardEncoder(std::move(mel)); + auto results = + decoder_->Decode(std::move(cross_kv.first), std::move(cross_kv.second)); + + auto r = Convert(results[0], symbol_table_); + s->SetResult(r); + } + + private: + static void NormalizeFeatures(float *features, int32_t num_frames, + int32_t feat_dim) { + // log_spec = torch.clamp(features, min=1e-10).log10() + // log_spec = torch.maximum(log_spec, log_spec.max() - 8.0) + // mel = (log_spec + 4.0) / 4.0 + + int32_t n = num_frames * feat_dim; + float max_v = -1e20; + for (int32_t i = 0; i != n; ++i) { + float f = features[i]; + + f = std::max(f, 1e-10); + f = std::log10(f); + + max_v = std::max(f, max_v); + + features[i] = f; + } + + max_v -= 8; + + for (int32_t i = 0; i != n; ++i) { + float f = features[i]; + f = std::max(f, max_v); + + f = (f + 4) / 4; + + features[i] = f; + } + } + + private: + OfflineRecognizerConfig config_; + SymbolTable symbol_table_; + std::unique_ptr model_; + std::unique_ptr decoder_; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_RECOGNIZER_WHISPER_IMPL_H_ diff --git 
a/sherpa-onnx/csrc/offline-stream.cc b/sherpa-onnx/csrc/offline-stream.cc index 15ed0389c..e317d8d5c 100644 --- a/sherpa-onnx/csrc/offline-stream.cc +++ b/sherpa-onnx/csrc/offline-stream.cc @@ -86,6 +86,15 @@ class OfflineStream::Impl { fbank_ = std::make_unique(opts_); } + Impl(WhisperTag /*tag*/, ContextGraphPtr context_graph) + : context_graph_(context_graph) { + config_.normalize_samples = true; + opts_.frame_opts.samp_freq = 16000; + opts_.mel_opts.num_bins = 80; + whisper_fbank_ = + std::make_unique(opts_.frame_opts); + } + void AcceptWaveform(int32_t sampling_rate, const float *waveform, int32_t n) { if (config_.normalize_samples) { AcceptWaveformImpl(sampling_rate, waveform, n); @@ -117,20 +126,35 @@ class OfflineStream::Impl { lowpass_filter_width); std::vector samples; resampler->Resample(waveform, n, true, &samples); - fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), - samples.size()); - fbank_->InputFinished(); + + if (fbank_) { + fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, samples.data(), + samples.size()); + fbank_->InputFinished(); + } else { + whisper_fbank_->AcceptWaveform(opts_.frame_opts.samp_freq, + samples.data(), samples.size()); + whisper_fbank_->InputFinished(); + } + return; - } + } // if (sampling_rate != opts_.frame_opts.samp_freq) - fbank_->AcceptWaveform(sampling_rate, waveform, n); - fbank_->InputFinished(); + if (fbank_) { + fbank_->AcceptWaveform(sampling_rate, waveform, n); + fbank_->InputFinished(); + } else { + whisper_fbank_->AcceptWaveform(sampling_rate, waveform, n); + whisper_fbank_->InputFinished(); + } } int32_t FeatureDim() const { return opts_.mel_opts.num_bins; } std::vector GetFrames() const { - int32_t n = fbank_->NumFramesReady(); + int32_t n = + fbank_ ? fbank_->NumFramesReady() : whisper_fbank_->NumFramesReady(); + assert(n > 0 && "Please first call AcceptWaveform()"); int32_t feature_dim = FeatureDim(); @@ -140,7 +164,8 @@ class OfflineStream::Impl { float *p = features.data(); for (int32_t i = 0; i != n; ++i) { - const float *f = fbank_->GetFrame(i); + const float *f = + fbank_ ? 
fbank_->GetFrame(i) : whisper_fbank_->GetFrame(i); std::copy(f, f + feature_dim, p); p += feature_dim; } @@ -191,6 +216,7 @@ class OfflineStream::Impl { private: OfflineFeatureExtractorConfig config_; std::unique_ptr fbank_; + std::unique_ptr whisper_fbank_; knf::FbankOptions opts_; OfflineRecognitionResult r_; ContextGraphPtr context_graph_; @@ -201,6 +227,10 @@ OfflineStream::OfflineStream( ContextGraphPtr context_graph /*= nullptr*/) : impl_(std::make_unique(config, context_graph)) {} +OfflineStream::OfflineStream(WhisperTag tag, + ContextGraphPtr context_graph /*= nullptr*/) + : impl_(std::make_unique(tag, context_graph)) {} + OfflineStream::~OfflineStream() = default; void OfflineStream::AcceptWaveform(int32_t sampling_rate, const float *waveform, diff --git a/sherpa-onnx/csrc/offline-stream.h b/sherpa-onnx/csrc/offline-stream.h index a21496bda..6eee0e545 100644 --- a/sherpa-onnx/csrc/offline-stream.h +++ b/sherpa-onnx/csrc/offline-stream.h @@ -65,10 +65,15 @@ struct OfflineFeatureExtractorConfig { void Register(ParseOptions *po); }; +struct WhisperTag {}; + class OfflineStream { public: explicit OfflineStream(const OfflineFeatureExtractorConfig &config = {}, ContextGraphPtr context_graph = nullptr); + + explicit OfflineStream(WhisperTag tag, + ContextGraphPtr context_graph = nullptr); ~OfflineStream(); /** diff --git a/sherpa-onnx/csrc/offline-transducer-model-config.cc b/sherpa-onnx/csrc/offline-transducer-model-config.cc index b90d68e7c..05fcc9092 100644 --- a/sherpa-onnx/csrc/offline-transducer-model-config.cc +++ b/sherpa-onnx/csrc/offline-transducer-model-config.cc @@ -18,17 +18,20 @@ void OfflineTransducerModelConfig::Register(ParseOptions *po) { bool OfflineTransducerModelConfig::Validate() const { if (!FileExists(encoder_filename)) { - SHERPA_ONNX_LOGE("encoder: %s does not exist", encoder_filename.c_str()); + SHERPA_ONNX_LOGE("transducer encoder: %s does not exist", + encoder_filename.c_str()); return false; } if (!FileExists(decoder_filename)) { - SHERPA_ONNX_LOGE("decoder: %s does not exist", decoder_filename.c_str()); + SHERPA_ONNX_LOGE("transducer decoder: %s does not exist", + decoder_filename.c_str()); return false; } if (!FileExists(joiner_filename)) { - SHERPA_ONNX_LOGE("joiner: %s does not exist", joiner_filename.c_str()); + SHERPA_ONNX_LOGE("transducer joiner: %s does not exist", + joiner_filename.c_str()); return false; } diff --git a/sherpa-onnx/csrc/offline-whisper-decoder.h b/sherpa-onnx/csrc/offline-whisper-decoder.h new file mode 100644 index 000000000..c9367eafd --- /dev/null +++ b/sherpa-onnx/csrc/offline-whisper-decoder.h @@ -0,0 +1,38 @@ +// sherpa-onnx/csrc/offline-whisper-decoder.h +// +// Copyright (c) 2023 Xiaomi Corporation + +#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ +#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ + +#include + +#include "onnxruntime_cxx_api.h" // NOLINT + +namespace sherpa_onnx { + +struct OfflineWhisperDecoderResult { + /// The decoded token IDs + std::vector tokens; +}; + +class OfflineWhisperDecoder { + public: + virtual ~OfflineWhisperDecoder() = default; + + /** Run beam search given the output from the whisper encoder model. + * + * @param n_layer_cross_k A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * @param n_layer_cross_v A 4-D tensor of shape + * (n_text_layer, N, n_audio_ctx, n_text_state). + * + * @return Return a vector of size `N` containing the decoded results. 
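+   *
+   * Note: Both inputs are the per-layer cross-attention keys and values
+   * computed by the whisper encoder model.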
+ */ + virtual std::vector Decode( + Ort::Value n_layer_cross_k, Ort::Value n_layer_cross_v) = 0; +}; + +} // namespace sherpa_onnx + +#endif // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_DECODER_H_ diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc new file mode 100644 index 000000000..1b2213002 --- /dev/null +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc @@ -0,0 +1,93 @@ +// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.cc +// +// Copyright (c) 2023 Xiaomi Corporation + +#include "sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h" + +#include +#include + +namespace sherpa_onnx { + +std::vector +OfflineWhisperGreedySearchDecoder::Decode(Ort::Value cross_k, + Ort::Value cross_v) { + auto memory_info = + Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeDefault); + + auto self_kv_cache = model_->GetInitialSelfKVCache(); + + std::vector initial_tokens = model_->GetInitialTokens(); + int32_t batch_size = 1; + std::array token_shape{ + batch_size, static_cast(initial_tokens.size())}; + + Ort::Value tokens = Ort::Value::CreateTensor( + memory_info, initial_tokens.data(), initial_tokens.size(), + token_shape.data(), token_shape.size()); + + std::array offset_shape{1}; + Ort::Value offset = Ort::Value::CreateTensor( + model_->Allocator(), offset_shape.data(), offset_shape.size()); + *(offset.GetTensorMutableData()) = 0; + + auto decoder_out = model_->ForwardDecoder( + std::move(tokens), std::move(self_kv_cache.first), + std::move(self_kv_cache.second), std::move(cross_k), std::move(cross_v), + std::move(offset)); + + const auto &logits = std::get<0>(decoder_out); + const float *p_logits = logits.GetTensorData(); + + auto logits_shape = logits.GetTensorTypeAndShapeInfo().GetShape(); + int32_t vocab_size = logits_shape[2]; + + int32_t max_token_id = static_cast(std::distance( + p_logits, std::max_element(p_logits, p_logits + vocab_size))); + + int32_t n_text_ctx = model_->TextCtx(); + + std::vector predicted_tokens; + for (int32_t i = 0; i < n_text_ctx; ++i) { + if (max_token_id == model_->EOT()) { + break; + } + + predicted_tokens.push_back(max_token_id); + + std::array token_shape{1, 1}; + Ort::Value tokens = Ort::Value::CreateTensor( + model_->Allocator(), token_shape.data(), token_shape.size()); + int64_t *p_tokens = tokens.GetTensorMutableData(); + p_tokens[0] = max_token_id; + + int64_t *p_offset = + std::get<5>(decoder_out).GetTensorMutableData(); + + if (i == 0) { + *p_offset = initial_tokens.size(); + } else { + *p_offset += 1; + } + + decoder_out = model_->ForwardDecoder(std::move(tokens), + std::move(std::get<1>(decoder_out)), + std::move(std::get<2>(decoder_out)), + std::move(std::get<3>(decoder_out)), + std::move(std::get<4>(decoder_out)), + std::move(std::get<5>(decoder_out))); + + const auto &logits = std::get<0>(decoder_out); + const float *p_logits = logits.GetTensorData(); + + max_token_id = static_cast(std::distance( + p_logits, std::max_element(p_logits, p_logits + vocab_size))); + } + + std::vector ans(1); + ans[0].tokens = std::move(predicted_tokens); + + return ans; +} + +} // namespace sherpa_onnx diff --git a/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h new file mode 100644 index 000000000..98e515b9f --- /dev/null +++ b/sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h @@ -0,0 +1,29 @@ +// sherpa-onnx/csrc/offline-whisper-greedy-search-decoder.h +// +// Copyright (c) 2023 Xiaomi 
Corporation
+
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
+
+#include <vector>
+
+#include "sherpa-onnx/csrc/offline-whisper-decoder.h"
+#include "sherpa-onnx/csrc/offline-whisper-model.h"
+
+namespace sherpa_onnx {
+
+class OfflineWhisperGreedySearchDecoder : public OfflineWhisperDecoder {
+ public:
+  explicit OfflineWhisperGreedySearchDecoder(OfflineWhisperModel *model)
+      : model_(model) {}
+
+  std::vector<OfflineWhisperDecoderResult> Decode(Ort::Value cross_k,
+                                                  Ort::Value cross_v) override;
+
+ private:
+  OfflineWhisperModel *model_;  // not owned
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_GREEDY_SEARCH_DECODER_H_
diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.cc b/sherpa-onnx/csrc/offline-whisper-model-config.cc
new file mode 100644
index 000000000..1a469e672
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-whisper-model-config.cc
@@ -0,0 +1,46 @@
+// sherpa-onnx/csrc/offline-whisper-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
+
+#include "sherpa-onnx/csrc/file-utils.h"
+#include "sherpa-onnx/csrc/macros.h"
+
+namespace sherpa_onnx {
+
+void OfflineWhisperModelConfig::Register(ParseOptions *po) {
+  po->Register("whisper-encoder", &encoder,
+               "Path to onnx encoder of whisper, e.g., tiny-encoder.onnx, "
+               "medium.en-encoder.onnx.");
+
+  po->Register("whisper-decoder", &decoder,
+               "Path to onnx decoder of whisper, e.g., tiny-decoder.onnx, "
+               "medium.en-decoder.onnx.");
+}
+
+bool OfflineWhisperModelConfig::Validate() const {
+  if (!FileExists(encoder)) {
+    SHERPA_ONNX_LOGE("whisper encoder file %s does not exist", encoder.c_str());
+    return false;
+  }
+
+  if (!FileExists(decoder)) {
+    SHERPA_ONNX_LOGE("whisper decoder file %s does not exist", decoder.c_str());
+    return false;
+  }
+
+  return true;
+}
+
+std::string OfflineWhisperModelConfig::ToString() const {
+  std::ostringstream os;
+
+  os << "OfflineWhisperModelConfig(";
+  os << "encoder=\"" << encoder << "\", ";
+  os << "decoder=\"" << decoder << "\")";
+
+  return os.str();
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-whisper-model-config.h b/sherpa-onnx/csrc/offline-whisper-model-config.h
new file mode 100644
index 000000000..03e533726
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-whisper-model-config.h
@@ -0,0 +1,30 @@
+// sherpa-onnx/csrc/offline-whisper-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
+
+#include <string>
+
+#include "sherpa-onnx/csrc/parse-options.h"
+
+namespace sherpa_onnx {
+
+struct OfflineWhisperModelConfig {
+  std::string encoder;
+  std::string decoder;
+
+  OfflineWhisperModelConfig() = default;
+  OfflineWhisperModelConfig(const std::string &encoder,
+                            const std::string &decoder)
+      : encoder(encoder), decoder(decoder) {}
+
+  void Register(ParseOptions *po);
+  bool Validate() const;
+
+  std::string ToString() const;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/csrc/offline-whisper-model.cc b/sherpa-onnx/csrc/offline-whisper-model.cc
new file mode 100644
index 000000000..31739384d
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-whisper-model.cc
@@ -0,0 +1,213 @@
+// sherpa-onnx/csrc/offline-whisper-model.cc
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-whisper-model.h"
+
+#include <array>
+#include <cstring>
+#include <string>
+#include <utility>
+
+#include "sherpa-onnx/csrc/macros.h"
+#include "sherpa-onnx/csrc/onnx-utils.h"
+#include "sherpa-onnx/csrc/session.h"
+#include "sherpa-onnx/csrc/text-utils.h"
+
+namespace sherpa_onnx {
+
+class OfflineWhisperModel::Impl {
+ public:
+  explicit Impl(const OfflineModelConfig &config)
+      : config_(config),
+        env_(ORT_LOGGING_LEVEL_ERROR),
+        sess_opts_(GetSessionOptions(config)),
+        allocator_{} {
+    {
+      auto buf = ReadFile(config.whisper.encoder);
+      InitEncoder(buf.data(), buf.size());
+    }
+
+    {
+      auto buf = ReadFile(config.whisper.decoder);
+      InitDecoder(buf.data(), buf.size());
+    }
+  }
+
+  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features) {
+    auto encoder_out = encoder_sess_->Run(
+        {}, encoder_input_names_ptr_.data(), &features, 1,
+        encoder_output_names_ptr_.data(), encoder_output_names_ptr_.size());
+
+    return {std::move(encoder_out[0]), std::move(encoder_out[1])};
+  }
+
+  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
+             Ort::Value>
+  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
+                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
+                 Ort::Value n_layer_cross_v, Ort::Value offset) {
+    std::array<Ort::Value, 6> decoder_input = {std::move(tokens),
+                                               std::move(n_layer_self_k_cache),
+                                               std::move(n_layer_self_v_cache),
+                                               std::move(n_layer_cross_k),
+                                               std::move(n_layer_cross_v),
+                                               std::move(offset)};
+
+    auto decoder_out = decoder_sess_->Run(
+        {}, decoder_input_names_ptr_.data(), decoder_input.data(),
+        decoder_input.size(), decoder_output_names_ptr_.data(),
+        decoder_output_names_ptr_.size());
+
+    return {std::move(decoder_out[0]),   std::move(decoder_out[1]),
+            std::move(decoder_out[2]),   std::move(decoder_input[3]),
+            std::move(decoder_input[4]), std::move(decoder_input[5])};
+  }
+
+  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache() {
+    std::array<int64_t, 4> shape{n_text_layer_, 1, n_text_ctx_, n_text_state_};
+
+    Ort::Value n_layer_self_k_cache = Ort::Value::CreateTensor<float>(
+        Allocator(), shape.data(), shape.size());
+
+    Ort::Value n_layer_self_v_cache = Ort::Value::CreateTensor<float>(
+        Allocator(), shape.data(), shape.size());
+
+    auto n = shape[0] * shape[1] * shape[2] * shape[3];
+
+    float *p_k = n_layer_self_k_cache.GetTensorMutableData<float>();
+    float *p_v = n_layer_self_v_cache.GetTensorMutableData<float>();
+
+    memset(p_k, 0, sizeof(float) * n);
+    memset(p_v, 0, sizeof(float) * n);
+
+    return {std::move(n_layer_self_k_cache), std::move(n_layer_self_v_cache)};
+  }
+
+  OrtAllocator *Allocator() const { return allocator_; }
+
+  const std::vector<int64_t> &GetInitialTokens() const { return sot_sequence_; }
+
+  int32_t EOT() const { return eot_; }
+
+  int32_t TextCtx() const { return n_text_ctx_; }
+
+ private:
+  void InitEncoder(void *model_data, size_t model_data_length) {
+    encoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(encoder_sess_.get(), &encoder_input_names_,
+                  &encoder_input_names_ptr_);
+
+    GetOutputNames(encoder_sess_.get(), &encoder_output_names_,
+                   &encoder_output_names_ptr_);
+
+    // get meta data
+    Ort::ModelMetadata meta_data = encoder_sess_->GetModelMetadata();
+    if (config_.debug) {
+      std::ostringstream os;
+      os << "---encoder---\n";
+      PrintModelMetadata(os, meta_data);
+      SHERPA_ONNX_LOGE("%s\n", os.str().c_str());
+    }
+
+    Ort::AllocatorWithDefaultOptions allocator;  // used in the macro below
+    SHERPA_ONNX_READ_META_DATA(n_text_layer_, "n_text_layer");
+    SHERPA_ONNX_READ_META_DATA(n_text_ctx_, "n_text_ctx");
+    SHERPA_ONNX_READ_META_DATA(n_text_state_, "n_text_state");
+    SHERPA_ONNX_READ_META_DATA(sot_, "sot");
+    SHERPA_ONNX_READ_META_DATA(eot_, "eot");
+    SHERPA_ONNX_READ_META_DATA(blank_, "blank_id");
+    SHERPA_ONNX_READ_META_DATA(translate_, "translate");
+    SHERPA_ONNX_READ_META_DATA(no_timestamps_, "no_timestamps");
+    SHERPA_ONNX_READ_META_DATA(no_speech_, "no_speech");
+    SHERPA_ONNX_READ_META_DATA_VEC(sot_sequence_, "sot_sequence");
+  }
+
+  void InitDecoder(void *model_data, size_t model_data_length) {
+    decoder_sess_ = std::make_unique<Ort::Session>(
+        env_, model_data, model_data_length, sess_opts_);
+
+    GetInputNames(decoder_sess_.get(), &decoder_input_names_,
+                  &decoder_input_names_ptr_);
+
+    GetOutputNames(decoder_sess_.get(), &decoder_output_names_,
+                   &decoder_output_names_ptr_);
+  }
+
+ private:
+  OfflineModelConfig config_;
+  Ort::Env env_;
+  Ort::SessionOptions sess_opts_;
+  Ort::AllocatorWithDefaultOptions allocator_;
+
+  std::unique_ptr<Ort::Session> encoder_sess_;
+  std::unique_ptr<Ort::Session> decoder_sess_;
+
+  std::vector<std::string> encoder_input_names_;
+  std::vector<const char *> encoder_input_names_ptr_;
+
+  std::vector<std::string> encoder_output_names_;
+  std::vector<const char *> encoder_output_names_ptr_;
+
+  std::vector<std::string> decoder_input_names_;
+  std::vector<const char *> decoder_input_names_ptr_;
+
+  std::vector<std::string> decoder_output_names_;
+  std::vector<const char *> decoder_output_names_ptr_;
+
+  // model meta data
+  int32_t n_text_layer_;
+  int32_t n_text_ctx_;
+  int32_t n_text_state_;
+  int32_t sot_;
+  int32_t eot_;
+  int32_t blank_;
+  int32_t translate_;
+  int32_t no_timestamps_;
+  int32_t no_speech_;
+  std::vector<int64_t> sot_sequence_;
+};
+
+OfflineWhisperModel::OfflineWhisperModel(const OfflineModelConfig &config)
+    : impl_(std::make_unique<Impl>(config)) {}
+
+OfflineWhisperModel::~OfflineWhisperModel() = default;
+
+std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::ForwardEncoder(
+    Ort::Value features) {
+  return impl_->ForwardEncoder(std::move(features));
+}
+
+std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
+           Ort::Value>
+OfflineWhisperModel::ForwardDecoder(Ort::Value tokens,
+                                    Ort::Value n_layer_self_k_cache,
+                                    Ort::Value n_layer_self_v_cache,
+                                    Ort::Value n_layer_cross_k,
+                                    Ort::Value n_layer_cross_v,
+                                    Ort::Value offset) {
+  return impl_->ForwardDecoder(
+      std::move(tokens), std::move(n_layer_self_k_cache),
+      std::move(n_layer_self_v_cache), std::move(n_layer_cross_k),
+      std::move(n_layer_cross_v), std::move(offset));
+}
+
+std::pair<Ort::Value, Ort::Value> OfflineWhisperModel::GetInitialSelfKVCache() {
+  return impl_->GetInitialSelfKVCache();
+}
+
+OrtAllocator *OfflineWhisperModel::Allocator() const {
+  return impl_->Allocator();
+}
+
+const std::vector<int64_t> &OfflineWhisperModel::GetInitialTokens() const {
+  return impl_->GetInitialTokens();
+}
+
+int32_t OfflineWhisperModel::EOT() const { return impl_->EOT(); }
+
+int32_t OfflineWhisperModel::TextCtx() const { return impl_->TextCtx(); }
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/offline-whisper-model.h b/sherpa-onnx/csrc/offline-whisper-model.h
new file mode 100644
index 000000000..4353e42f8
--- /dev/null
+++ b/sherpa-onnx/csrc/offline-whisper-model.h
@@ -0,0 +1,85 @@
+// sherpa-onnx/csrc/offline-whisper-model.h
+//
+// Copyright (c) 2022-2023 Xiaomi Corporation
+#ifndef SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
+#define SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
+
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+#include "onnxruntime_cxx_api.h"  // NOLINT
+#include "sherpa-onnx/csrc/offline-model-config.h"
+
+namespace sherpa_onnx {
+
+class OfflineWhisperModel {
+ public:
+  explicit OfflineWhisperModel(const OfflineModelConfig &config);
+  ~OfflineWhisperModel();
+
+  /** Run the encoder model.
+   *
+   * @param features A tensor of shape (N, C, T). It is changed in-place.
+   *                 C is 80 and T is 3000.
+   *
+   * @return Return a pair containing:
+   *  - n_layer_cross_k: A 4-D tensor of shape
+   *                     (n_text_layer, N, n_audio_ctx, n_text_state)
+   *  - n_layer_cross_v: A 4-D tensor of shape
+   *                     (n_text_layer, N, n_audio_ctx, n_text_state)
+   */
+  std::pair<Ort::Value, Ort::Value> ForwardEncoder(Ort::Value features);
+
+  /** Run the decoder model.
+   *
+   * @param tokens An int64 tensor of shape (N, num_words)
+   * @param n_layer_self_k_cache A 4-D tensor of shape
+   *                             (n_text_layer, N, n_text_ctx, n_text_state).
+   * @param n_layer_self_v_cache A 4-D tensor of shape
+   *                             (n_text_layer, N, n_text_ctx, n_text_state).
+   * @param n_layer_cross_k A 4-D tensor of shape
+   *                        (n_text_layer, N, n_audio_ctx, n_text_state).
+   * @param n_layer_cross_v A 4-D tensor of shape
+   *                        (n_text_layer, N, n_audio_ctx, n_text_state).
+   * @param offset An int64 tensor of shape (N,)
+   *
+   * @return Return a tuple containing 6 tensors:
+   *
+   *  - logits A 3-D tensor of shape (N, num_words, vocab_size)
+   *  - out_n_layer_self_k_cache Same shape as n_layer_self_k_cache
+   *  - out_n_layer_self_v_cache Same shape as n_layer_self_v_cache
+   *  - out_n_layer_cross_k Same as n_layer_cross_k
+   *  - out_n_layer_cross_v Same as n_layer_cross_v
+   *  - out_offset Same as offset
+   */
+  std::tuple<Ort::Value, Ort::Value, Ort::Value, Ort::Value, Ort::Value,
+             Ort::Value>
+  ForwardDecoder(Ort::Value tokens, Ort::Value n_layer_self_k_cache,
+                 Ort::Value n_layer_self_v_cache, Ort::Value n_layer_cross_k,
+                 Ort::Value n_layer_cross_v, Ort::Value offset);
+
+  /** Return the initial self-attention KV cache as a pair:
+   *  - n_layer_self_k_cache A 4-D tensor of shape
+   *                         (n_text_layer, N, n_text_ctx, n_text_state).
+   *  - n_layer_self_v_cache A 4-D tensor of shape
+   *                         (n_text_layer, N, n_text_ctx, n_text_state).
+   */
+  std::pair<Ort::Value, Ort::Value> GetInitialSelfKVCache();
+  const std::vector<int64_t> &GetInitialTokens() const;
+
+  /** Return an allocator for allocating memory. */
+  OrtAllocator *Allocator() const;
+  int32_t EOT() const;
+  int32_t TextCtx() const;
+
+ private:
+  class Impl;
+  std::unique_ptr<Impl> impl_;
+};
+
+}  // namespace sherpa_onnx
+
+#endif  // SHERPA_ONNX_CSRC_OFFLINE_WHISPER_MODEL_H_
diff --git a/sherpa-onnx/csrc/sherpa-onnx-microphone-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-microphone-offline.cc
index 9a1f58256..a587ffa44 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-microphone-offline.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-microphone-offline.cc
@@ -98,11 +98,15 @@ This program uses non-streaming models with microphone for speech recognition.
   ./bin/sherpa-onnx-microphone-offline \
     --tokens=/path/to/tokens.txt \
     --paraformer=/path/to/model.onnx \
-    --num-threads=2 \
-    --decoding-method=greedy_search
+    --num-threads=1
 
-Default value for num_threads is 2.
-Valid values for decoding_method: greedy_search.
+(3) Whisper models
+
+  ./bin/sherpa-onnx-microphone-offline \
+    --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+    --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+    --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+    --num-threads=1
 
 Please refer to
 https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html
diff --git a/sherpa-onnx/csrc/sherpa-onnx-offline.cc b/sherpa-onnx/csrc/sherpa-onnx-offline.cc
index 71c47e0d1..c51549af4 100644
--- a/sherpa-onnx/csrc/sherpa-onnx-offline.cc
+++ b/sherpa-onnx/csrc/sherpa-onnx-offline.cc
@@ -23,7 +23,7 @@ int main(int32_t argc, char *argv[]) {
     --encoder=/path/to/encoder.onnx \
     --decoder=/path/to/decoder.onnx \
     --joiner=/path/to/joiner.onnx \
-    --num-threads=2 \
+    --num-threads=1 \
     --decoding-method=greedy_search \
     /path/to/foo.wav [bar.wav foobar.wav ...]
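The encoder/decoder tensor layout documented in offline-whisper-model.h above can also be exercised from Python with onnxruntime. The sketch below runs a single greedy step against an exported model; the file names, the comma-separated format of the "sot_sequence" metadata entry, and the positional input ordering are assumptions based on scripts/whisper/export-onnx.py, so treat it as an illustration rather than the C++ implementation.

# Illustrative sketch only: drive the exported whisper ONNX models directly.
# File names and the "sot_sequence" metadata format are assumptions.
import numpy as np
import onnxruntime as ort

encoder = ort.InferenceSession("tiny.en-encoder.onnx")
decoder = ort.InferenceSession("tiny.en-decoder.onnx")

meta = encoder.get_modelmeta().custom_metadata_map
n_text_layer = int(meta["n_text_layer"])
n_text_ctx = int(meta["n_text_ctx"])
n_text_state = int(meta["n_text_state"])
sot_sequence = [int(t) for t in meta["sot_sequence"].split(",")]
eot = int(meta["eot"])

# Encoder: (N, 80, 3000) log-mel features -> cross-attention k/v caches.
mel = np.zeros((1, 80, 3000), dtype=np.float32)
cross_k, cross_v = encoder.run(None, {encoder.get_inputs()[0].name: mel})

# Decoder: tokens + zero-initialized self-attention k/v caches + offset.
self_k = np.zeros((n_text_layer, 1, n_text_ctx, n_text_state), dtype=np.float32)
self_v = np.zeros_like(self_k)
tokens = np.array([sot_sequence], dtype=np.int64)
offset = np.zeros((1,), dtype=np.int64)

names = [i.name for i in decoder.get_inputs()]
feeds = dict(zip(names, [tokens, self_k, self_v, cross_k, cross_v, offset]))
logits, self_k, self_v = decoder.run(None, feeds)[:3]

# Greedy pick of the next token; a real loop would append it to `tokens`,
# advance `offset`, and stop once `eot` is produced.
next_token = int(logits[0, -1].argmax())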
@@ -33,14 +33,22 @@ int main(int32_t argc, char *argv[]) {
   ./bin/sherpa-onnx-offline \
     --tokens=/path/to/tokens.txt \
     --paraformer=/path/to/model.onnx \
-    --num-threads=2 \
+    --num-threads=1 \
     --decoding-method=greedy_search \
     /path/to/foo.wav [bar.wav foobar.wav ...]
+
+(3) Whisper models
+
+  ./bin/sherpa-onnx-offline \
+    --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
+    --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
+    --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
+    --num-threads=1 \
+    /path/to/foo.wav [bar.wav foobar.wav ...]
+
+Note: It supports decoding multiple files in batches
 
-Default value for num_threads is 2.
-Valid values for decoding_method: greedy_search.
 foo.wav should be of single channel, 16-bit PCM encoded wave file; its
 sampling rate can be arbitrary and does not need to be 16kHz.
@@ -55,6 +63,7 @@ for a list of pre-trained models to download.
   po.Read(argc, argv);
   if (po.NumArgs() < 1) {
+    fprintf(stderr, "Error: Please provide at least 1 wave file.\n\n");
     po.PrintUsage();
     exit(EXIT_FAILURE);
   }
diff --git a/sherpa-onnx/csrc/symbol-table.cc b/sherpa-onnx/csrc/symbol-table.cc
index 6f18bdadc..692783b4e 100644
--- a/sherpa-onnx/csrc/symbol-table.cc
+++ b/sherpa-onnx/csrc/symbol-table.cc
@@ -9,6 +9,7 @@
 #include
 #include
+#include "sherpa-onnx/csrc/base64-decode.h"
 #include "sherpa-onnx/csrc/onnx-utils.h"
 
 #if __ANDROID_API__ >= 9
@@ -82,4 +83,12 @@ std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) {
   return os << symbol_table.ToString();
 }
 
+void SymbolTable::ApplyBase64Decode() {
+  sym2id_.clear();
+  for (auto &p : id2sym_) {
+    p.second = Base64Decode(p.second);
+    sym2id_[p.second] = p.first;
+  }
+}
+
 }  // namespace sherpa_onnx
diff --git a/sherpa-onnx/csrc/symbol-table.h b/sherpa-onnx/csrc/symbol-table.h
index 103e0f27e..7a83ab24f 100644
--- a/sherpa-onnx/csrc/symbol-table.h
+++ b/sherpa-onnx/csrc/symbol-table.h
@@ -45,6 +45,9 @@ class SymbolTable {
   /// Return true if there is a given symbol in the symbol table.
   bool contains(const std::string &sym) const;
 
+  // for tokens.txt from Whisper
+  void ApplyBase64Decode();
+
  private:
   void Init(std::istream &is);
diff --git a/sherpa-onnx/python/csrc/CMakeLists.txt b/sherpa-onnx/python/csrc/CMakeLists.txt
index ce62a36c9..b1d9db522 100644
--- a/sherpa-onnx/python/csrc/CMakeLists.txt
+++ b/sherpa-onnx/python/csrc/CMakeLists.txt
@@ -11,6 +11,7 @@ pybind11_add_module(_sherpa_onnx
   offline-recognizer.cc
   offline-stream.cc
   offline-transducer-model-config.cc
+  offline-whisper-model-config.cc
   online-lm-config.cc
   online-recognizer.cc
   online-stream.cc
diff --git a/sherpa-onnx/python/csrc/offline-model-config.cc b/sherpa-onnx/python/csrc/offline-model-config.cc
index 6665e85f7..cfec6f148 100644
--- a/sherpa-onnx/python/csrc/offline-model-config.cc
+++ b/sherpa-onnx/python/csrc/offline-model-config.cc
@@ -11,6 +11,7 @@
 #include "sherpa-onnx/python/csrc/offline-nemo-enc-dec-ctc-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-paraformer-model-config.h"
 #include "sherpa-onnx/python/csrc/offline-transducer-model-config.h"
+#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
 
 namespace sherpa_onnx {
 
@@ -18,22 +19,25 @@ void PybindOfflineModelConfig(py::module *m) {
   PybindOfflineTransducerModelConfig(m);
   PybindOfflineParaformerModelConfig(m);
   PybindOfflineNemoEncDecCtcModelConfig(m);
+  PybindOfflineWhisperModelConfig(m);
 
   using PyClass = OfflineModelConfig;
   py::class_<PyClass>(*m, "OfflineModelConfig")
-      .def(
-          py::init<const OfflineTransducerModelConfig &,
-                   const OfflineParaformerModelConfig &,
-                   const OfflineNemoEncDecCtcModelConfig &, const std::string &,
-                   int32_t, bool, const std::string &, const std::string &>(),
-          py::arg("transducer") = OfflineTransducerModelConfig(),
-          py::arg("paraformer") = OfflineParaformerModelConfig(),
-          py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
-          py::arg("tokens"), py::arg("num_threads"), py::arg("debug") = false,
-          py::arg("provider") = "cpu", py::arg("model_type") = "")
+      .def(py::init<const OfflineTransducerModelConfig &,
+                    const OfflineParaformerModelConfig &,
+                    const OfflineNemoEncDecCtcModelConfig &,
+                    const OfflineWhisperModelConfig &, const std::string &,
+                    int32_t, bool, const std::string &, const std::string &>(),
+           py::arg("transducer") = OfflineTransducerModelConfig(),
+           py::arg("paraformer") = OfflineParaformerModelConfig(),
+           py::arg("nemo_ctc") = OfflineNemoEncDecCtcModelConfig(),
+           py::arg("whisper") = OfflineWhisperModelConfig(), py::arg("tokens"),
+           py::arg("num_threads"), py::arg("debug") = false,
+           py::arg("provider") = "cpu", py::arg("model_type") = "")
       .def_readwrite("transducer", &PyClass::transducer)
       .def_readwrite("paraformer", &PyClass::paraformer)
       .def_readwrite("nemo_ctc", &PyClass::nemo_ctc)
+      .def_readwrite("whisper", &PyClass::whisper)
       .def_readwrite("tokens", &PyClass::tokens)
       .def_readwrite("num_threads", &PyClass::num_threads)
       .def_readwrite("debug", &PyClass::debug)
diff --git a/sherpa-onnx/python/csrc/offline-whisper-model-config.cc b/sherpa-onnx/python/csrc/offline-whisper-model-config.cc
new file mode 100644
index 000000000..274704927
--- /dev/null
+++ b/sherpa-onnx/python/csrc/offline-whisper-model-config.cc
@@ -0,0 +1,24 @@
+// sherpa-onnx/python/csrc/offline-whisper-model-config.cc
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#include "sherpa-onnx/csrc/offline-whisper-model-config.h"
+
+#include <string>
+#include <vector>
+
+#include "sherpa-onnx/python/csrc/offline-whisper-model-config.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineWhisperModelConfig(py::module *m) {
+  using PyClass = OfflineWhisperModelConfig;
+  py::class_<PyClass>(*m, "OfflineWhisperModelConfig")
+      .def(py::init<const std::string &, const std::string &>(),
+           py::arg("encoder"), py::arg("decoder"))
+      .def_readwrite("encoder", &PyClass::encoder)
+      .def_readwrite("decoder", &PyClass::decoder)
+      .def("__str__", &PyClass::ToString);
+}
+
+}  // namespace sherpa_onnx
diff --git a/sherpa-onnx/python/csrc/offline-whisper-model-config.h b/sherpa-onnx/python/csrc/offline-whisper-model-config.h
new file mode 100644
index 000000000..0b240585a
--- /dev/null
+++ b/sherpa-onnx/python/csrc/offline-whisper-model-config.h
@@ -0,0 +1,16 @@
+// sherpa-onnx/python/csrc/offline-whisper-model-config.h
+//
+// Copyright (c) 2023 Xiaomi Corporation
+
+#ifndef SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
+#define SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
+
+#include "sherpa-onnx/python/csrc/sherpa-onnx.h"
+
+namespace sherpa_onnx {
+
+void PybindOfflineWhisperModelConfig(py::module *m);
+
+}
+
+#endif  // SHERPA_ONNX_PYTHON_CSRC_OFFLINE_WHISPER_MODEL_CONFIG_H_
diff --git a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
index 32fad47a1..0312c01c5 100644
--- a/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
+++ b/sherpa-onnx/python/sherpa_onnx/offline_recognizer.py
@@ -1,4 +1,5 @@
 # Copyright (c) 2023 by manyeyes
+# Copyright (c) 2023 Xiaomi Corporation
 
 from pathlib import Path
 from typing import List, Optional
@@ -7,6 +8,7 @@
     OfflineModelConfig,
     OfflineNemoEncDecCtcModelConfig,
     OfflineParaformerModelConfig,
+    OfflineWhisperModelConfig,
 )
 from _sherpa_onnx import OfflineRecognizer as _Recognizer
 from _sherpa_onnx import (
@@ -69,7 +71,7 @@ def from_transducer(
           feature_dim:
             Dimension of the feature used to train the model.
           decoding_method:
-            Support only greedy_search for now.
+            Valid values: greedy_search, modified_beam_search.
           debug:
             True to show debug messages.
           provider:
@@ -137,7 +139,7 @@ def from_paraformer(
           feature_dim:
             Dimension of the feature used to train the model.
           decoding_method:
-            Valid values are greedy_search, modified_beam_search.
+            Valid values are greedy_search.
           debug:
             True to show debug messages.
           provider:
@@ -185,14 +187,14 @@ def from_nemo_ctc(
         English, etc.
 
         Args:
+          model:
+            Path to ``model.onnx``.
           tokens:
             Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
             columns::
 
                 symbol integer_id
 
-          model:
-            Path to ``model.onnx``.
           num_threads:
             Number of threads for neural network computation.
           sample_rate:
@@ -200,7 +202,7 @@ def from_nemo_ctc(
           feature_dim:
             Dimension of the feature used to train the model.
           decoding_method:
-            Valid values are greedy_search, modified_beam_search.
+            Valid values are greedy_search.
           debug:
             True to show debug messages.
           provider:
@@ -229,6 +231,68 @@ def from_nemo_ctc(
         self.recognizer = _Recognizer(recognizer_config)
         return self
 
+    @classmethod
+    def from_whisper(
+        cls,
+        encoder: str,
+        decoder: str,
+        tokens: str,
+        num_threads: int,
+        decoding_method: str = "greedy_search",
+        debug: bool = False,
+        provider: str = "cpu",
+    ):
+        """
+        Please refer to
+        `<https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html>`_
+        to download pre-trained models for different kinds of whisper models,
+        e.g., tiny, tiny.en, base, base.en, etc.
+
+        Args:
+          encoder:
+            Path to the encoder model, e.g., tiny-encoder.onnx,
+            tiny-encoder.int8.onnx, tiny-encoder.ort, etc.
+          decoder:
+            Path to the decoder model, e.g., tiny-decoder.onnx,
+            tiny-decoder.int8.onnx, tiny-decoder.ort, etc.
+          tokens:
+            Path to ``tokens.txt``. Each line in ``tokens.txt`` contains two
+            columns::
+
+                symbol integer_id
+
+          num_threads:
+            Number of threads for neural network computation.
+          decoding_method:
+            Valid values: greedy_search.
+          debug:
+            True to show debug messages.
+          provider:
+            onnxruntime execution providers. Valid values are: cpu, cuda, coreml.
+ """ + self = cls.__new__(cls) + model_config = OfflineModelConfig( + whisper=OfflineWhisperModelConfig(encoder=encoder, decoder=decoder), + tokens=tokens, + num_threads=num_threads, + debug=debug, + provider=provider, + model_type="whisper", + ) + + feat_config = OfflineFeatureExtractorConfig( + sampling_rate=16000, + feature_dim=80, + ) + + recognizer_config = OfflineRecognizerConfig( + feat_config=feat_config, + model_config=model_config, + decoding_method=decoding_method, + ) + self.recognizer = _Recognizer(recognizer_config) + return self + def create_stream(self, contexts_list: Optional[List[List[int]]] = None): if contexts_list is None: return self.recognizer.create_stream()