Support whisper models (#238)

csukuangfj authored Aug 7, 2023
1 parent 64efbd8 commit 45b9d4a

Showing 39 changed files with 1,835 additions and 51 deletions.
63 changes: 63 additions & 0 deletions .github/workflows/export-whisper-to-onnx.yaml
@@ -0,0 +1,63 @@
name: export-whisper-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: release-whisper-${{ github.ref }}
  cancel-in-progress: true

jobs:
  release-whisper-models:
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: ${{ matrix.model }}
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        model: ["tiny.en", "base.en", "small.en", "medium.en"]

    steps:
      - uses: actions/checkout@v2

      - name: Install dependencies
        shell: bash
        run: |
          python3 -m pip install openai-whisper torch onnxruntime onnx

      - name: export ${{ matrix.model }}
        shell: bash
        run: |
          cd scripts/whisper
          python3 ./export-onnx.py --model ${{ matrix.model }}
          python3 -m onnxruntime.tools.convert_onnx_models_to_ort --optimization_style=Fixed ./
          ls -lh
          ls -lh ~/.cache/whisper

      - name: Publish ${{ matrix.model }} to huggingface
        shell: bash
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        run: |
          cd scripts/whisper
          git config --global user.email "[email protected]"
          git config --global user.name "Fangjun Kuang"
          GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} huggingface
          cp *.onnx ./huggingface
          cp *.ort ./huggingface
          cp *tokens.txt ./huggingface
          cd huggingface
          git status
          ls -lh
          git lfs track "*.onnx"
          git lfs track "*.ort"
          git add .
          git commit -m "upload ${{ matrix.model }}"
          git push https://csukuangfj:[email protected]/csukuangfj/sherpa-onnx-whisper-${{ matrix.model }} main
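
Once this workflow runs, each model in the matrix ends up in its own Hugging Face repository. A sketch of fetching one of them from Python (huggingface_hub is an assumed extra dependency, not used by this commit; the repo id follows the sherpa-onnx-whisper-<model> pattern from the publish step):

# Sketch only: pull the files published by the workflow above.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(repo_id="csukuangfj/sherpa-onnx-whisper-tiny.en")
print("model files downloaded to", local_dir)
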
4 changes: 2 additions & 2 deletions .github/workflows/run-java-test.yaml
@@ -23,14 +23,14 @@ on:
      - 'sherpa-onnx/jni/*'

concurrency:
-  group: jni-${{ github.ref }}
+  group: run-java-test-${{ github.ref }}
  cancel-in-progress: true

permissions:
  contents: read

jobs:
-  jni:
+  run_java_test:
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,7 +1,7 @@
cmake_minimum_required(VERSION 3.13 FATAL_ERROR)
project(sherpa-onnx)

set(SHERPA_ONNX_VERSION "1.5.5")
set(SHERPA_ONNX_VERSION "1.6.0")

# Disable warning about
#
16 changes: 8 additions & 8 deletions cmake/kaldi-native-fbank.cmake
@@ -1,9 +1,9 @@
function(download_kaldi_native_fbank)
  include(FetchContent)

-  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.17.tar.gz")
-  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.17.tar.gz")
-  set(kaldi_native_fbank_HASH "SHA256=300dc282d51d738e70f194ef13a50bf4cf8d54a3b2686d75f7fc2fb821f8c1e6")
+  set(kaldi_native_fbank_URL "https://github.com/csukuangfj/kaldi-native-fbank/archive/refs/tags/v1.18.1.tar.gz")
+  set(kaldi_native_fbank_URL2 "https://huggingface.co/csukuangfj/sherpa-onnx-cmake-deps/resolve/main/kaldi-native-fbank-1.18.1.tar.gz")
+  set(kaldi_native_fbank_HASH "SHA256=c7676f319fa97e8c8bca6018792de120895dcfe122fa9b4bff00f8f9165348e7")

  set(KALDI_NATIVE_FBANK_BUILD_TESTS OFF CACHE BOOL "" FORCE)
  set(KALDI_NATIVE_FBANK_BUILD_PYTHON OFF CACHE BOOL "" FORCE)
@@ -12,11 +12,11 @@ function(download_kaldi_native_fbank)
  # If you don't have access to the Internet,
  # please pre-download kaldi-native-fbank
  set(possible_file_locations
-    $ENV{HOME}/Downloads/kaldi-native-fbank-1.17.tar.gz
-    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.17.tar.gz
-    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.17.tar.gz
-    /tmp/kaldi-native-fbank-1.17.tar.gz
-    /star-fj/fangjun/download/github/kaldi-native-fbank-1.17.tar.gz
+    $ENV{HOME}/Downloads/kaldi-native-fbank-1.18.1.tar.gz
+    ${PROJECT_SOURCE_DIR}/kaldi-native-fbank-1.18.1.tar.gz
+    ${PROJECT_BINARY_DIR}/kaldi-native-fbank-1.18.1.tar.gz
+    /tmp/kaldi-native-fbank-1.18.1.tar.gz
+    /star-fj/fangjun/download/github/kaldi-native-fbank-1.18.1.tar.gz
  )

  foreach(f IN LISTS possible_file_locations)
58 changes: 58 additions & 0 deletions python-api-examples/offline-decode-files.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3
#
# Copyright (c) 2023 by manyeyes
# Copyright (c) 2023 Xiaomi Corporation

"""
This file demonstrates how to use sherpa-onnx Python API to transcribe
@@ -34,6 +35,27 @@
(3) For CTC models from NeMo
python3 ./python-api-examples/offline-decode-files.py \
  --tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
  --nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
  --num-threads=2 \
  --decoding-method=greedy_search \
  --debug=false \
  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
  ./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav

(4) For Whisper models

python3 ./python-api-examples/offline-decode-files.py \
  --whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
  --whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
  --tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
  --num-threads=1 \
  ./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
  ./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
  ./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
Please refer to
https://k2-fsa.github.io/sherpa/onnx/index.html
to install sherpa-onnx and to download the pre-trained models
@@ -144,6 +166,20 @@ def get_args():
help="Number of threads for neural network computation",
)

parser.add_argument(
"--whisper-encoder",
default="",
type=str,
help="Path to whisper encoder model",
)

parser.add_argument(
"--whisper-decoder",
default="",
type=str,
help="Path to whisper decoder model",
)

parser.add_argument(
"--decoding-method",
type=str,
@@ -247,6 +283,8 @@ def main():
    if args.encoder:
        assert len(args.paraformer) == 0, args.paraformer
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        contexts = [x.strip().upper() for x in args.contexts.split("/") if x.strip()]
        if contexts:
@@ -271,6 +309,9 @@
        )
    elif args.paraformer:
        assert len(args.nemo_ctc) == 0, args.nemo_ctc
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        assert_file_exists(args.paraformer)

        recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
@@ -283,6 +324,11 @@
            debug=args.debug,
        )
    elif args.nemo_ctc:
        assert len(args.whisper_encoder) == 0, args.whisper_encoder
        assert len(args.whisper_decoder) == 0, args.whisper_decoder

        assert_file_exists(args.nemo_ctc)

        recognizer = sherpa_onnx.OfflineRecognizer.from_nemo_ctc(
            model=args.nemo_ctc,
            tokens=args.tokens,
@@ -292,6 +338,18 @@
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    elif args.whisper_encoder:
        assert_file_exists(args.whisper_encoder)
        assert_file_exists(args.whisper_decoder)

        recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
            encoder=args.whisper_encoder,
            decoder=args.whisper_decoder,
            tokens=args.tokens,
            num_threads=args.num_threads,
            decoding_method=args.decoding_method,
            debug=args.debug,
        )
    else:
        print("Please specify at least one model")
        return
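
As a rough, self-contained sketch of what the new from_whisper branch enables (the model file names follow the docstring example above and are assumptions, not part of this diff):

# Sketch only: decode one 16-bit mono wav with the whisper factory added in
# this commit; paths mirror the docstring example and may differ locally.
import wave

import numpy as np
import sherpa_onnx

recognizer = sherpa_onnx.OfflineRecognizer.from_whisper(
    encoder="./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx",
    decoder="./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx",
    tokens="./sherpa-onnx-whisper-base.en/base.en-tokens.txt",
    num_threads=1,
    decoding_method="greedy_search",
    debug=False,
)

with wave.open("./sherpa-onnx-whisper-base.en/test_wavs/0.wav") as f:
    assert f.getnchannels() == 1, "expects a mono file"
    assert f.getsampwidth() == 2, "expects 16-bit PCM"
    sample_rate = f.getframerate()
    samples = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16)
    samples = samples.astype(np.float32) / 32768  # normalize to [-1, 1)

stream = recognizer.create_stream()
stream.accept_waveform(sample_rate, samples)
recognizer.decode_streams([stream])
print(stream.result.text)
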
4 changes: 4 additions & 0 deletions scripts/whisper/.gitignore
@@ -0,0 +1,4 @@
*.onnx
*.config
*.ort
*-tokens.txt
9 changes: 9 additions & 0 deletions scripts/whisper/README.md
@@ -0,0 +1,9 @@
# Introduction

This folder contains code showing how to convert [Whisper][whisper] to onnx
and use onnxruntime to replace PyTorch for speech recognition.

You can use [sherpa-onnx][sherpa-onnx] to run the converted model.

[whisper]: https://github.com/openai/whisper
[sherpa-onnx]: https://github.com/k2-fsa/sherpa-onnx
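
As a quick sanity check after export (a sketch, not part of this commit), an exported model can be opened directly with onnxruntime to confirm it loads and to list its inputs and outputs; the file name below assumes the <model>-encoder.onnx naming used elsewhere in this commit:

# Sketch only: verify an exported model loads under onnxruntime.
# "tiny.en-encoder.onnx" is an assumed name following this commit's pattern.
import onnxruntime as ort

session = ort.InferenceSession("tiny.en-encoder.onnx", providers=["CPUExecutionProvider"])
for i in session.get_inputs():
    print("input :", i.name, i.shape, i.type)
for o in session.get_outputs():
    print("output:", o.name, o.shape, o.type)
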