Skip to content

Commit

Permalink
Support exporting LSTM with projection to ONNX (#621)
Browse files Browse the repository at this point in the history
* Support exporting LSTM with projection to ONNX

* Add missing files

* small fixes
  • Loading branch information
csukuangfj authored Oct 18, 2022
1 parent d1f16a0 commit d69bb82
Show file tree
Hide file tree
Showing 14 changed files with 1,002 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ for sym in 1 2 3; do
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$rep/test_wavs/BAC009S0764W0123.wav
$repo/test_wavs/BAC009S0764W0123.wav
done

for method in modified_beam_search beam_search fast_beam_search; do
Expand All @@ -55,7 +55,7 @@ for method in modified_beam_search beam_search fast_beam_search; do
--lang-dir $repo/data/lang_char \
$repo/test_wavs/BAC009S0764W0121.wav \
$repo/test_wavs/BAC009S0764W0122.wav \
$rep/test_wavs/BAC009S0764W0123.wav
$repo/test_wavs/BAC009S0764W0123.wav
done

echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,47 @@ log "Decode with models exported by torch.jit.trace()"
$repo/test_wavs/1221-135766-0001.wav \
$repo/test_wavs/1221-135766-0002.wav

log "Test exporting to ONNX"

./lstm_transducer_stateless2/export.py \
--exp-dir $repo/exp \
--bpe-model $repo/data/lang_bpe_500/bpe.model \
--epoch 99 \
--avg 1 \
--use-averaged-model 0 \
--onnx 1

log "Decode with ONNX models "

./lstm_transducer_stateless2/streaming-onnx-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo//exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1089-134686-0001.wav

./lstm_transducer_stateless2/streaming-onnx-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo//exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1221-135766-0001.wav

./lstm_transducer_stateless2/streaming-onnx-decode.py \
--bpe-model-filename $repo/data/lang_bpe_500/bpe.model \
--encoder-model-filename $repo//exp/encoder.onnx \
--decoder-model-filename $repo/exp/decoder.onnx \
--joiner-model-filename $repo/exp/joiner.onnx \
--joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \
--joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \
$repo/test_wavs/1221-135766-0002.wav



for sym in 1 2 3; do
log "Greedy search with --max-sym-per-frame $sym"

Expand Down Expand Up @@ -133,7 +174,7 @@ done

echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}"
echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}"
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then
if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then
mkdir -p lstm_transducer_stateless2/exp
ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt
ln -s $PWD/$repo/data/lang_bpe_500 data/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ cd egs/librispeech/ASR
repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29

log "Downloading pre-trained model from $repo_url"
git lfs install
git clone $repo_url
GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url
repo=$(basename $repo_url)

pushd $repo
git lfs pull --include "data/lang_bpe_500/bpe.model"
git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt"
popd

log "Display test files"
tree $repo/
soxi $repo/test_wavs/*.wav
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: run-librispeech-lstm-transducer-2022-09-03
name: run-librispeech-lstm-transducer2-2022-09-03

on:
push:
Expand All @@ -17,8 +17,8 @@ on:
- cron: "50 15 * * *"

jobs:
run_librispeech_pruned_transducer_stateless3_2022_05_13:
if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule'
run_librispeech_lstm_transducer_stateless2_2022_09_03:
if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule'
runs-on: ${{ matrix.os }}
strategy:
matrix:
Expand Down Expand Up @@ -110,7 +110,7 @@ jobs:
.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml
- name: Display decoding results for lstm_transducer_stateless2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
if: github.event_name == 'schedule'
shell: bash
run: |
cd egs/librispeech/ASR
Expand All @@ -130,7 +130,7 @@ jobs:
- name: Upload decoding results for lstm_transducer_stateless2
uses: actions/upload-artifact@v2
if: github.event_name == 'schedule' || github.event.label.name == 'ncnn'
if: github.event_name == 'schedule'
with:
name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
1 change: 1 addition & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py
Loading

0 comments on commit d69bb82

Please sign in to comment.