-
Notifications
You must be signed in to change notification settings - Fork 414
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Export NeMo FastConformer Hybrid Transducer-CTC Large Streaming to ONNX. (#843)
- Loading branch information
1 parent
dbaa26f
commit a9f936e
Showing
5 changed files
with
431 additions
and
0 deletions.
There are no files selected for viewing
73 changes: 73 additions & 0 deletions
73
.github/workflows/export-nemo-fast-conformer-hybrid-transducer-ctc.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
# Manually triggered workflow that exports the CTC branch of the NeMo
# streaming FastConformer hybrid models to ONNX, packs each exported model
# together with test waves, and uploads the archives to a release.
name: export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx

on:
  workflow_dispatch:

concurrency:
  group: export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx-${{ github.ref }}
  cancel-in-progress: true

jobs:
  export-nemo-fast-conformer-hybrid-transducer-ctc-to-onnx:
    # Only run in the upstream repository or the maintainer's fork.
    if: github.repository_owner == 'k2-fsa' || github.repository_owner == 'csukuangfj'
    name: export NeMo fast conformer
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os: [macos-latest]
        python-version: ["3.10"]

    steps:
      - uses: actions/checkout@v4

      - name: Setup Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install NeMo
        shell: bash
        run: |
          BRANCH='main'
          pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[asr]
          pip install onnxruntime
          pip install kaldi-native-fbank
          pip install soundfile librosa

      # run-ctc.sh creates one sherpa-onnx-nemo-* directory per model;
      # move them to the repository root for the packaging step below.
      - name: Run
        shell: bash
        run: |
          cd scripts/nemo/fast-conformer-hybrid-transducer-ctc
          ./run-ctc.sh
          mv -v sherpa-onnx-nemo* ../../..

      # Copy the same test waves into every model directory, then create
      # one .tar.bz2 archive per model.
      - name: Download test waves
        shell: bash
        run: |
          mkdir test_wavs
          pushd test_wavs
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/0.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/1.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/8k.wav
          curl -SL -O https://hf-mirror.com/csukuangfj/sherpa-onnx-nemo-ctc-en-conformer-small/resolve/main/test_wavs/trans.txt
          popd
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms
          cp -av test_wavs ./sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-80ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-480ms
          tar cjvf sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms.tar.bz2 sherpa-onnx-nemo-streaming-fast-conformer-ctc-1040ms

      - name: Release
        uses: svenstaro/upload-release-action@v2
        with:
          file_glob: true
          file: ./*.tar.bz2
          overwrite: true
          repo_name: k2-fsa/sherpa-onnx
          repo_token: ${{ secrets.UPLOAD_GH_SHERPA_ONNX_TOKEN }}
          tag: asr-models
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Introduction | ||
|
||
This folder contains scripts for exporting models from | ||
|
||
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_80ms | ||
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_480ms | ||
- https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_fastconformer_hybrid_large_streaming_1040ms | ||
|
||
to `sherpa-onnx`. |
117 changes: 117 additions & 0 deletions
117
scripts/nemo/fast-conformer-hybrid-transducer-ctc/export-onnx-ctc.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
#!/usr/bin/env python3 | ||
import argparse | ||
from typing import Dict | ||
|
||
import nemo.collections.asr as nemo_asr | ||
import onnx | ||
import torch | ||
|
||
|
||
def get_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the export script.

    Args:
      argv:
        Optional list of argument strings to parse. When it is None
        (the default), ``sys.argv[1:]`` is parsed, as before.

    Returns:
      The parsed namespace; ``model`` is one of "80", "480" or "1040".
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        choices=["80", "480", "1040"],
        help="Selects the pretrained model "
        "stt_en_fastconformer_hybrid_large_streaming_<MODEL>ms to export.",
    )
    return parser.parse_args(argv)
|
||
|
||
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Replace the metadata of an ONNX model file with the given entries.

    The model is loaded from ``filename``, every existing metadata property
    is removed, the supplied entries are attached, and the model is saved
    back to the same path.

    Args:
      filename:
        Path of the ONNX model; the file is overwritten.
      meta_data:
        Entries to store. Each value is converted with ``str()``.
    """
    onnx_model = onnx.load(filename)
    props = onnx_model.metadata_props

    # Drop whatever metadata is already stored in the model.
    while len(props) > 0:
        props.pop()

    for name in meta_data:
        entry = props.add()
        entry.key = name
        entry.value = str(meta_data[name])

    onnx.save(onnx_model, filename)
|
||
|
||
def _second_if_list(value):
    """Return ``value[1]`` when ``value`` is a list, otherwise ``value``."""
    return value[1] if isinstance(value, list) else value


@torch.no_grad()
def main():
    """Export the CTC branch of a pretrained streaming model to ONNX.

    The model name is built from the ``--model`` command-line option.
    ``./tokens.txt`` and ``model.onnx`` are written to the current
    directory, and the streaming parameters read from the encoder are stored
    as metadata inside ``model.onnx``.

    Raises:
      RuntimeError:
        If the encoder of the loaded model has no streaming configuration.
    """
    args = get_args()
    model_name = f"stt_en_fastconformer_hybrid_large_streaming_{args.model}ms"

    asr_model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)

    vocabulary = asr_model.joint.vocabulary
    with open("./tokens.txt", "w", encoding="utf-8") as f:
        for i, s in enumerate(vocabulary):
            f.write(f"{s} {i}\n")
        # The blank symbol gets the id right after the last vocabulary entry.
        # len() is used instead of the leaked loop variable so that an empty
        # vocabulary does not raise NameError.
        f.write(f"<blk> {len(vocabulary)}\n")
    print("Saved to tokens.txt")

    decoder_type = "ctc"
    asr_model.change_decoding_strategy(decoder_type=decoder_type)
    asr_model.eval()

    streaming_cfg = asr_model.encoder.streaming_cfg
    # Raise explicitly: a bare assert is stripped when Python runs with -O.
    if streaming_cfg is None:
        raise RuntimeError(f"{model_name} has no streaming configuration")

    # Both settings may be either a scalar or a list; for a list, the
    # element at index 1 is used.
    chunk_size = _second_if_list(streaming_cfg.chunk_size)
    pre_encode_cache_size = _second_if_list(streaming_cfg.pre_encode_cache_size)
    window_size = chunk_size + pre_encode_cache_size

    print("chunk_size", chunk_size)
    print("pre_encode_cache_size", pre_encode_cache_size)
    print("window_size", window_size)

    chunk_shift = chunk_size

    # cache_last_channel: (batch_size, dim1, dim2, dim3)
    cache_last_channel_dim1 = len(asr_model.encoder.layers)
    cache_last_channel_dim2 = streaming_cfg.last_channel_cache_size
    cache_last_channel_dim3 = asr_model.encoder.d_model

    # cache_last_time: (batch_size, dim1, dim2, dim3)
    cache_last_time_dim1 = len(asr_model.encoder.layers)
    cache_last_time_dim2 = asr_model.encoder.d_model
    cache_last_time_dim3 = asr_model.encoder.conv_context_size[0]

    asr_model.set_export_config({"decoder_type": "ctc", "cache_support": True})

    filename = "model.onnx"

    asr_model.export(filename)

    meta_data = {
        "vocab_size": asr_model.decoder.vocab_size,
        "window_size": window_size,
        "chunk_shift": chunk_shift,
        "normalize_type": "None",
        "cache_last_channel_dim1": cache_last_channel_dim1,
        "cache_last_channel_dim2": cache_last_channel_dim2,
        "cache_last_channel_dim3": cache_last_channel_dim3,
        "cache_last_time_dim1": cache_last_time_dim1,
        "cache_last_time_dim2": cache_last_time_dim2,
        "cache_last_time_dim3": cache_last_time_dim3,
        # NOTE(review): hard-coded; confirm it matches the encoder's
        # actual subsampling factor for all three models.
        "subsampling_factor": 8,
        "model_type": "EncDecHybridRNNTCTCBPEModel",
        "version": "1",
        "model_author": "NeMo",
        "url": f"https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/{model_name}",
        "comment": "Only the CTC branch is exported",
    }
    add_meta_data(filename, meta_data)

    print(meta_data)


if __name__ == "__main__":
    main()
35 changes: 35 additions & 0 deletions
35
scripts/nemo/fast-conformer-hybrid-transducer-ctc/run-ctc.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#!/usr/bin/env bash
#
# Export the CTC branch of each streaming model listed in `ms` to ONNX with
# export-onnx-ctc.py, place the results in one directory per model, and then
# run test-onnx-ctc.py on every exported model using ./0.wav.

set -ex

# Download the test wave only if it is not already present.
if [ ! -e ./0.wav ]; then
  # Alternative mirror:
  # curl -SL -O https://hf-mirror.com/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
  curl -SL -O https://huggingface.co/csukuangfj/icefall-asr-librispeech-streaming-zipformer-small-2024-03-18/resolve/main/test_wavs/0.wav
fi

# Values accepted by the --model option of export-onnx-ctc.py.
ms=(
  80
  480
  1040
)

# Expansions are quoted so that values are never word-split or globbed.
for m in "${ms[@]}"; do
  ./export-onnx-ctc.py --model "$m"
  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
  # The exported files are moved only when the target directory does not
  # already contain a model.
  if [ ! -f "$d/model.onnx" ]; then
    mkdir -p "$d"
    mv -v model.onnx "$d/"
    mv -v tokens.txt "$d/"
    ls -lh "$d"
  fi
done

# Now test the exported models

for m in "${ms[@]}"; do
  d=sherpa-onnx-nemo-streaming-fast-conformer-ctc-${m}ms
  python3 ./test-onnx-ctc.py \
    --model "$d/model.onnx" \
    --tokens "$d/tokens.txt" \
    --wav ./0.wav
done
Oops, something went wrong.