From d69bb826eddd6dda9988547b90dfc2a3af3edd01 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 18 Oct 2022 11:25:31 +0800 Subject: [PATCH] Support exporting LSTM with projection to ONNX (#621) * Support exporting LSTM with projection to ONNX * Add missing files * small fixes --- ...pruned-transducer-stateless3-2022-06-20.sh | 4 +- ...-lstm-transducer-stateless2-2022-09-03.yml | 43 +- ...pruned-transducer-stateless2-2022-04-29.sh | 8 +- ...-lstm-transducer-stateless2-2022-09-03.yml | 10 +- .../ASR/lstm_transducer_stateless/lstmp.py | 1 + .../ASR/lstm_transducer_stateless2/export.py | 277 +++++++++- .../ASR/lstm_transducer_stateless2/lstmp.py | 102 ++++ .../streaming-ncnn-decode.py | 3 +- .../streaming-onnx-decode.py | 478 ++++++++++++++++++ .../lstm_transducer_stateless2/test_lstmp.py | 70 +++ .../ASR/lstm_transducer_stateless3/lstmp.py | 1 + .../ASR/pruned_transducer_stateless3/lstmp.py | 1 + .../scaling_converter.py | 18 +- .../ASR/pruned_transducer_stateless2/lstmp.py | 1 + 14 files changed, 1002 insertions(+), 15 deletions(-) create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py create mode 100644 egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py create mode 100755 egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py create mode 120000 egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh index aab2883a98..e70a1848d6 100755 --- a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -42,7 +42,7 @@ for sym in 1 2 3; do --lang-dir $repo/data/lang_char \ $repo/test_wavs/BAC009S0764W0121.wav \ $repo/test_wavs/BAC009S0764W0122.wav \ - $rep/test_wavs/BAC009S0764W0123.wav + $repo/test_wavs/BAC009S0764W0123.wav done for method in modified_beam_search beam_search fast_beam_search; do @@ -55,7 +55,7 @@ for method in modified_beam_search beam_search fast_beam_search; do --lang-dir $repo/data/lang_char \ $repo/test_wavs/BAC009S0764W0121.wav \ $repo/test_wavs/BAC009S0764W0122.wav \ - $rep/test_wavs/BAC009S0764W0123.wav + $repo/test_wavs/BAC009S0764W0123.wav done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 3d57a895c1..b89055c728 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -105,6 +105,47 @@ log "Decode with models exported by torch.jit.trace()" $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav +log "Test exporting to ONNX" + +./lstm_transducer_stateless2/export.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + --onnx 1 + +log "Decode with ONNX models " + +./lstm_transducer_stateless2/streaming-onnx-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo//exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx 
\ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1089-134686-0001.wav + +./lstm_transducer_stateless2/streaming-onnx-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo//exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1221-135766-0001.wav + +./lstm_transducer_stateless2/streaming-onnx-decode.py \ + --bpe-model-filename $repo/data/lang_bpe_500/bpe.model \ + --encoder-model-filename $repo//exp/encoder.onnx \ + --decoder-model-filename $repo/exp/decoder.onnx \ + --joiner-model-filename $repo/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename $repo/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename $repo/exp/joiner_decoder_proj.onnx \ + $repo/test_wavs/1221-135766-0002.wav + + + for sym in 1 2 3; do log "Greedy search with --max-sym-per-frame $sym" @@ -133,7 +174,7 @@ done echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" -if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"ncnn" ]]; then +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then mkdir -p lstm_transducer_stateless2/exp ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt ln -s $PWD/$repo/data/lang_bpe_500 data/ diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh index d1e4a3991b..ae2bb68227 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless2-2022-04-29.sh @@ -13,10 +13,14 @@ cd egs/librispeech/ASR repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless2-2022-04-29 log "Downloading pre-trained model from $repo_url" -git lfs install -git clone $repo_url +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "exp/pretrained-epoch-38-avg-10.pt" +popd + log "Display test files" tree $repo/ soxi $repo/test_wavs/*.wav diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 9558b81529..dd67771bab 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -1,4 +1,4 @@ -name: run-librispeech-lstm-transducer-2022-09-03 +name: run-librispeech-lstm-transducer2-2022-09-03 on: push: @@ -17,8 +17,8 @@ on: - cron: "50 15 * * *" jobs: - run_librispeech_pruned_transducer_stateless3_2022_05_13: - if: github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' + run_librispeech_lstm_transducer_stateless2_2022_09_03: + if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: 
@@ -110,7 +110,7 @@ jobs: .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml - name: Display decoding results for lstm_transducer_stateless2 - if: github.event_name == 'schedule' || github.event.label.name == 'ncnn' + if: github.event_name == 'schedule' shell: bash run: | cd egs/librispeech/ASR @@ -130,7 +130,7 @@ jobs: - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'ncnn' + if: github.event_name == 'schedule' with: name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03 path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/ diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py new file mode 120000 index 0000000000..4f377cd010 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 8c4b243b06..1906736382 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -74,6 +74,29 @@ git lfs install git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03 # You will find the pre-trained models in icefall-asr-librispeech-lstm-transducer-stateless2-2022-09-03/exp + +(3) Export to ONNX format + +./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +It will generate the following files in the given `exp_dir`. + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + +Please see ./streaming-onnx-decode.py for usage of the generated files + +Check +https://github.com/k2-fsa/sherpa-onnx +for how to use the exported models outside of icefall. """ import argparse @@ -181,6 +204,23 @@ def get_parser(): """, ) + parser.add_argument( + "--onnx", + type=str2bool, + default=False, + help="""If True, --jit and --pnnx are ignored and it exports the model + to onnx format. It will generate the following files: + + - encoder.onnx + - decoder.onnx + - joiner.onnx + - joiner_encoder_proj.onnx + - joiner_decoder_proj.onnx + + Refer to ./onnx_check.py and ./onnx_pretrained.py for how to use them. + """, + ) + parser.add_argument( "--context-size", type=int, @@ -266,6 +306,215 @@ def export_joiner_model_jit_trace( logging.info(f"Saved to {joiner_filename}") +def export_encoder_model_onnx( + encoder_model: nn.Module, + encoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the given encoder model to ONNX format. 
+ The exported model has 3 inputs: + + - x, a tensor of shape (N, T, C); dtype is torch.float32 + - x_lens, a tensor of shape (N,); dtype is torch.int64 + - states: a tuple containing: + - h0: a tensor of shape (num_layers, N, proj_size) + - c0: a tensor of shape (num_layers, N, hidden_size) + + and it has 3 outputs: + + - encoder_out, a tensor of shape (N, T, C) + - encoder_out_lens, a tensor of shape (N,) + - states: a tuple containing: + - next_h0: a tensor of shape (num_layers, N, proj_size) + - next_c0: a tensor of shape (num_layers, N, hidden_size) + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + N = 1 + x = torch.zeros(N, 9, 80, dtype=torch.float32) + x_lens = torch.tensor([9], dtype=torch.int64) + h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) + c = torch.rand( + encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size + ) + + warmup = 1.0 + torch.onnx.export( + encoder_model, # use torch.jit.trace() internally + (x, x_lens, (h, c), warmup), + encoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["x", "x_lens", "h", "c", "warmup"], + output_names=["encoder_out", "encoder_out_lens", "next_h", "next_c"], + dynamic_axes={ + "x": {0: "N", 1: "T"}, + "x_lens": {0: "N"}, + "h": {1: "N"}, + "c": {1: "N"}, + "encoder_out": {0: "N", 1: "T"}, + "encoder_out_lens": {0: "N"}, + "next_h": {1: "N"}, + "next_c": {1: "N"}, + }, + ) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_onnx( + decoder_model: nn.Module, + decoder_filename: str, + opset_version: int = 11, +) -> None: + """Export the decoder model to ONNX format. + + The exported model has one input: + + - y: a torch.int64 tensor of shape (N, decoder_model.context_size) + + and has one output: + + - decoder_out: a torch.float32 tensor of shape (N, 1, C) + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The decoder model to be exported. + decoder_filename: + Filename to save the exported ONNX model. + opset_version: + The opset version to use. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = False # Always False, so we can use torch.jit.trace() here + # Note(fangjun): torch.jit.trace() is more efficient than torch.jit.script() + # in this case + torch.onnx.export( + decoder_model, + (y, need_pad), + decoder_filename, + verbose=False, + opset_version=opset_version, + input_names=["y", "need_pad"], + output_names=["decoder_out"], + dynamic_axes={ + "y": {0: "N"}, + "decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_onnx( + joiner_model: nn.Module, + joiner_filename: str, + opset_version: int = 11, +) -> None: + """Export the joiner model to ONNX format. 
+ The exported joiner model has two inputs: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + - projected_decoder_out: a tensor of shape (N, joiner_dim) + + and produces one output: + + - logit: a tensor of shape (N, vocab_size) + + The exported encoder_proj model has one input: + + - encoder_out: a tensor of shape (N, encoder_out_dim) + + and produces one output: + + - projected_encoder_out: a tensor of shape (N, joiner_dim) + + The exported decoder_proj model has one input: + + - decoder_out: a tensor of shape (N, decoder_out_dim) + + and produces one output: + + - projected_decoder_out: a tensor of shape (N, joiner_dim) + """ + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) + + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) + + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + joiner_dim = joiner_model.decoder_proj.weight.shape[0] + + projected_encoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + projected_decoder_out = torch.rand(1, joiner_dim, dtype=torch.float32) + + project_input = False + # Note: It uses torch.jit.trace() internally + torch.onnx.export( + joiner_model, + (projected_encoder_out, projected_decoder_out, project_input), + joiner_filename, + verbose=False, + opset_version=opset_version, + input_names=[ + "projected_encoder_out", + "projected_decoder_out", + "project_input", + ], + output_names=["logit"], + dynamic_axes={ + "projected_encoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + "logit": {0: "N"}, + }, + ) + logging.info(f"Saved to {joiner_filename}") + + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.encoder_proj, + encoder_out, + encoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["encoder_out"], + output_names=["projected_encoder_out"], + dynamic_axes={ + "encoder_out": {0: "N"}, + "projected_encoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {encoder_proj_filename}") + + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + torch.onnx.export( + joiner_model.decoder_proj, + decoder_out, + decoder_proj_filename, + verbose=False, + opset_version=opset_version, + input_names=["decoder_out"], + output_names=["projected_decoder_out"], + dynamic_axes={ + "decoder_out": {0: "N"}, + "projected_decoder_out": {0: "N"}, + }, + ) + logging.info(f"Saved to {decoder_proj_filename}") + + @torch.no_grad() def main(): args = get_parser().parse_args() @@ -387,7 +636,33 @@ def main(): model.to("cpu") model.eval() - if params.pnnx: + if params.onnx: + logging.info("Export model to ONNX format") + convert_scaled_to_non_scaled(model, inplace=True, is_onnx=True) + + opset_version = 11 + encoder_filename = params.exp_dir / "encoder.onnx" + export_encoder_model_onnx( + model.encoder, + encoder_filename, + opset_version=opset_version, + ) + + decoder_filename = params.exp_dir / "decoder.onnx" + export_decoder_model_onnx( + model.decoder, + decoder_filename, + opset_version=opset_version, + ) + + joiner_filename = params.exp_dir / "joiner.onnx" + export_joiner_model_onnx( + model.joiner, + joiner_filename, + opset_version=opset_version, + ) + + elif params.pnnx: convert_scaled_to_non_scaled(model, inplace=True) logging.info("Using torch.jit.trace()") encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt" diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py 
b/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py new file mode 100644 index 0000000000..dba6eb5209 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/lstmp.py @@ -0,0 +1,102 @@ +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LSTMP(nn.Module): + """LSTM with projection. + + PyTorch does not support exporting LSTM with projection to ONNX. + This class reimplements LSTM with projection using basic matrix-matrix + and matrix-vector operations. It is not intended for training. + """ + + def __init__(self, lstm: nn.LSTM): + """ + Args: + lstm: + LSTM with proj_size. We support only uni-directional, + 1-layer LSTM with projection at present. + """ + super().__init__() + assert lstm.bidirectional is False, lstm.bidirectional + assert lstm.num_layers == 1, lstm.num_layers + assert 0 < lstm.proj_size < lstm.hidden_size, ( + lstm.proj_size, + lstm.hidden_size, + ) + + assert lstm.batch_first is False, lstm.batch_first + + state_dict = lstm.state_dict() + + w_ih = state_dict["weight_ih_l0"] + w_hh = state_dict["weight_hh_l0"] + + b_ih = state_dict["bias_ih_l0"] + b_hh = state_dict["bias_hh_l0"] + + w_hr = state_dict["weight_hr_l0"] + self.input_size = lstm.input_size + self.proj_size = lstm.proj_size + self.hidden_size = lstm.hidden_size + + self.w_ih = w_ih + self.w_hh = w_hh + self.b = b_ih + b_hh + self.w_hr = w_hr + + def forward( + self, + input: torch.Tensor, + hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + input: + A tensor of shape [T, N, hidden_size] + hx: + A tuple containing: + - h0: a tensor of shape (1, N, proj_size) + - c0: a tensor of shape (1, N, hidden_size) + Returns: + Return a tuple containing: + - output: a tensor of shape (T, N, proj_size). 
+ - A tuple containing: + - h: a tensor of shape (1, N, proj_size) + - c: a tensor of shape (1, N, hidden_size) + + """ + x_list = input.unbind(dim=0) # We use batch_first=False + + if hx is not None: + h0, c0 = hx + else: + h0 = torch.zeros(1, input.size(1), self.proj_size) + c0 = torch.zeros(1, input.size(1), self.hidden_size) + h0 = h0.squeeze(0) + c0 = c0.squeeze(0) + y_list = [] + for x in x_list: + gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh) + i, f, g, o = gates.chunk(4, dim=1) + + i = i.sigmoid() + f = f.sigmoid() + g = g.tanh() + o = o.sigmoid() + + c = f * c0 + i * g + h = o * c.tanh() + + h = F.linear(h, self.w_hr) + y_list.append(h) + + c0 = c + h0 = h + + y = torch.stack(y_list, dim=0) + + return y, (h0.unsqueeze(0), c0.unsqueeze(0)) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index da78413f78..e47a05a9e4 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -233,13 +233,12 @@ def greedy_search( hyp, dtype=torch.int32 ) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) - else: assert decoder_out.ndim == 1 assert hyp is not None, hyp joiner_out = model.run_joiner(encoder_out, decoder_out) - y = joiner_out.argmax(dim=0).tolist() + y = joiner_out.argmax(dim=0).item() if y != blank_id: hyp.append(y) decoder_input = hyp[-context_size:] diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py new file mode 100755 index 0000000000..1c9ec3e898 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -0,0 +1,478 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads ONNX models and uses them to decode waves. 
+You can use the following command to get the exported models: + +./lstm_transducer_stateless2/export.py \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --onnx 1 + +Usage of this script: + +./lstm_transducer_stateless2/onnx-streaming-decode.py \ + --encoder-model-filename ./lstm_transducer_stateless2/exp/encoder.onnx \ + --decoder-model-filename ./lstm_transducer_stateless2/exp/decoder.onnx \ + --joiner-model-filename ./lstm_transducer_stateless2/exp/joiner.onnx \ + --joiner-encoder-proj-model-filename ./lstm_transducer_stateless2/exp/joiner_encoder_proj.onnx \ + --joiner-decoder-proj-model-filename ./lstm_transducer_stateless2/exp/joiner_decoder_proj.onnx \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import onnxruntime as ort +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--bpe-model-filename", + type=str, + help="Path to bpe.model", + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder onnx model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder onnx model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner onnx model. ", + ) + + parser.add_argument( + "--joiner-encoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner encoder_proj onnx model. ", + ) + + parser.add_argument( + "--joiner-decoder-proj-model-filename", + type=str, + required=True, + help="Path to the joiner decoder_proj onnx model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser.parse_args() + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +class Model: + def __init__(self, args): + session_opts = ort.SessionOptions() + session_opts.inter_op_num_threads = 5 + session_opts.intra_op_num_threads = 5 + self.session_opts = session_opts + + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + self.init_joiner_encoder_proj(args) + self.init_joiner_decoder_proj(args) + + def init_encoder(self, args): + self.encoder = ort.InferenceSession( + args.encoder_model_filename, + sess_options=self.session_opts, + ) + + def init_decoder(self, args): + self.decoder = ort.InferenceSession( + args.decoder_model_filename, + sess_options=self.session_opts, + ) + + def init_joiner(self, args): + self.joiner = ort.InferenceSession( + args.joiner_model_filename, + sess_options=self.session_opts, + ) + + def init_joiner_encoder_proj(self, args): + self.joiner_encoder_proj = ort.InferenceSession( + args.joiner_encoder_proj_model_filename, + sess_options=self.session_opts, + ) + + def init_joiner_decoder_proj(self, args): + self.joiner_decoder_proj = ort.InferenceSession( + args.joiner_decoder_proj_model_filename, + sess_options=self.session_opts, + ) + + def run_encoder( + self, x, h0, c0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A tensor of shape (N, T, C) + h0: + A tensor of shape (num_layers, N, proj_size) + c0: + A tensor of shape (num_layers, N, hidden_size) + Returns: + Return a tuple containing: + - encoder_out: A tensor of shape (N, T', C') + - next_h0: A tensor of shape (num_layers, N, proj_size) + - next_c0: A tensor of shape (num_layers, N, hidden_size) + """ + encoder_input_nodes = self.encoder.get_inputs() + encoder_out_nodes = self.encoder.get_outputs() + x_lens = torch.tensor([x.size(1)], dtype=torch.int64) + + encoder_out, encoder_out_lens, next_h0, next_c0 = self.encoder.run( + [ + encoder_out_nodes[0].name, + encoder_out_nodes[1].name, + encoder_out_nodes[2].name, + encoder_out_nodes[3].name, + ], + { + encoder_input_nodes[0].name: x.numpy(), + encoder_input_nodes[1].name: x_lens.numpy(), + encoder_input_nodes[2].name: h0.numpy(), + encoder_input_nodes[3].name: c0.numpy(), + }, + ) + return ( + torch.from_numpy(encoder_out), + torch.from_numpy(next_h0), + torch.from_numpy(next_c0), + ) + + def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: + """ + Args: + decoder_input: + A tensor of shape (N, context_size). Its dtype is torch.int64. + Returns: + Return a tensor of shape (N, 1, decoder_out_dim). 
+ """ + decoder_input_nodes = self.decoder.get_inputs() + decoder_output_nodes = self.decoder.get_outputs() + + decoder_out = self.decoder.run( + [decoder_output_nodes[0].name], + { + decoder_input_nodes[0].name: decoder_input.numpy(), + }, + )[0] + + return self.run_joiner_decoder_proj( + torch.from_numpy(decoder_out).squeeze(1) + ) + + def run_joiner( + self, + projected_encoder_out: torch.Tensor, + projected_decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + projected_encoder_out: + A tensor of shape (N, joiner_dim) + projected_decoder_out: + A tensor of shape (N, joiner_dim) + Returns: + Return a tensor of shape (N, vocab_size) + """ + joiner_input_nodes = self.joiner.get_inputs() + joiner_output_nodes = self.joiner.get_outputs() + + logits = self.joiner.run( + [joiner_output_nodes[0].name], + { + joiner_input_nodes[0].name: projected_encoder_out.numpy(), + joiner_input_nodes[1].name: projected_decoder_out.numpy(), + }, + )[0] + + return torch.from_numpy(logits) + + def run_joiner_encoder_proj( + self, + encoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + encoder_out: + A tensor of shape (N, encoder_out_dim) + Returns: + A tensor of shape (N, joiner_dim) + """ + + projected_encoder_out = self.joiner_encoder_proj.run( + [self.joiner_encoder_proj.get_outputs()[0].name], + { + self.joiner_encoder_proj.get_inputs()[ + 0 + ].name: encoder_out.numpy() + }, + )[0] + + return torch.from_numpy(projected_encoder_out) + + def run_joiner_decoder_proj( + self, + decoder_out: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + decoder_out: + A tensor of shape (N, decoder_out_dim) + Returns: + A tensor of shape (N, joiner_dim) + """ + + projected_decoder_out = self.joiner_decoder_proj.run( + [self.joiner_decoder_proj.get_outputs()[0].name], + { + self.joiner_decoder_proj.get_inputs()[ + 0 + ].name: decoder_out.numpy() + }, + )[0] + + return torch.from_numpy(projected_decoder_out) + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + assert encoder_out.shape[0] == 1, "TODO: support batch_size > 1" + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor( + [hyp], dtype=torch.int64 + ) # (1, context_size) + decoder_out = model.run_decoder(decoder_input) + else: + assert decoder_out.shape[0] == 1 + assert hyp is not None, hyp + + projected_encoder_out = model.run_joiner_encoder_proj(encoder_out) + + joiner_out = model.run_joiner(projected_encoder_out, decoder_out) + y = joiner_out.squeeze(0).argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor([decoder_input], dtype=torch.int64) + decoder_out = model.run_decoder(decoder_input) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + sample_rate = 16000 + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model_filename) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + num_encoder_layers = 12 + batch_size = 1 + d_model = 512 + rnn_hidden_size = 1024 + + h0 = torch.zeros(num_encoder_layers, batch_size, d_model) + c0 = torch.zeros(num_encoder_layers, batch_size, rnn_hidden_size) + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = 9 + offset = 4 + + chunk = 3200 # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + + num_processed_frames += offset + frames = torch.cat(frames, dim=0).unsqueeze(0) + encoder_out, h0, c0 = model.run_encoder(frames, h0, c0) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + online_fbank.accept_waveform( + sampling_rate=sample_rate, waveform=torch.zeros(5000, dtype=torch.float) + ) + + online_fbank.input_finished() + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0).unsqueeze(0) + encoder_out, h0, c0 = model.run_encoder(frames, h0, c0) + hyp, decoder_out = greedy_search( + model, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + + logging.info(sound_file) + logging.info(sp.decode(hyp[context_size:])) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() 
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py new file mode 100755 index 0000000000..00ba224cdb --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/test_lstmp.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +from lstmp import LSTMP + + +def test(): + input_size = torch.randint(low=10, high=1024, size=(1,)).item() + hidden_size = torch.randint(low=10, high=1024, size=(1,)).item() + proj_size = hidden_size - 1 + lstm = nn.LSTM( + input_size=input_size, + hidden_size=hidden_size, + num_layers=1, + bias=True, + proj_size=proj_size, + ) + lstmp = LSTMP(lstm) + + N = torch.randint(low=1, high=10, size=(1,)).item() + T = torch.randint(low=1, high=20, size=(1,)).item() + x = torch.rand(T, N, input_size) + h0 = torch.rand(1, N, proj_size) + c0 = torch.rand(1, N, hidden_size) + + y1, (h1, c1) = lstm(x, (h0, c0)) + y2, (h2, c2) = lstmp(x, (h0, c0)) + + assert torch.allclose(y1, y2, atol=1e-5), (y1 - y2).abs().max() + assert torch.allclose(h1, h2, atol=1e-5), (h1 - h2).abs().max() + assert torch.allclose(c1, c2, atol=1e-5), (c1 - c2).abs().max() + + # lstm_script = torch.jit.script(lstm) # pytorch does not support it + lstm_script = lstm + lstmp_script = torch.jit.script(lstmp) + + y3, (h3, c3) = lstm_script(x, (h0, c0)) + y4, (h4, c4) = lstmp_script(x, (h0, c0)) + + assert torch.allclose(y3, y4, atol=1e-5), (y3 - y4).abs().max() + assert torch.allclose(h3, h4, atol=1e-5), (h3 - h4).abs().max() + assert torch.allclose(c3, c4, atol=1e-5), (c3 - c4).abs().max() + + assert torch.allclose(y3, y1, atol=1e-5), (y3 - y1).abs().max() + assert torch.allclose(h3, h1, atol=1e-5), (h3 - h1).abs().max() + assert torch.allclose(c3, c1, atol=1e-5), (c3 - c1).abs().max() + + lstm_trace = torch.jit.trace(lstm, (x, (h0, c0))) + lstmp_trace = torch.jit.trace(lstmp, (x, (h0, c0))) + + y5, (h5, c5) = lstm_trace(x, (h0, c0)) + y6, (h6, c6) = lstmp_trace(x, (h0, c0)) + + assert torch.allclose(y5, y6, atol=1e-5), (y5 - y6).abs().max() + assert torch.allclose(h5, h6, atol=1e-5), (h5 - h6).abs().max() + assert torch.allclose(c5, c6, atol=1e-5), (c5 - c6).abs().max() + + assert torch.allclose(y5, y1, atol=1e-5), (y5 - y1).abs().max() + assert torch.allclose(h5, h1, atol=1e-5), (h5 - h1).abs().max() + assert torch.allclose(c5, c1, atol=1e-5), (c5 - c1).abs().max() + + +@torch.no_grad() +def main(): + test() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py new file mode 120000 index 0000000000..4f377cd010 --- /dev/null +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py new file mode 120000 index 0000000000..4f377cd010 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index f2f691eb1a..1e7e808c73 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -29,6 +29,7 @@ import torch import 
torch.nn as nn +from lstmp import LSTMP from scaling import ( ActivationBalancer, BasicNorm, @@ -259,7 +260,11 @@ def get_submodule(model, target): return mod -def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): +def convert_scaled_to_non_scaled( + model: nn.Module, + inplace: bool = False, + is_onnx: bool = False, +): """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, and `nn.Conv2d`. @@ -270,6 +275,9 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): inplace: If True, the input model is modified inplace. If False, the input model is copied and we modify the copied version. + is_onnx: + If True, we are going to export the model to ONNX. In this case, + we will convert nn.LSTM with proj_size to LSTMP. Return: Return a model without scaled layers. """ @@ -294,7 +302,13 @@ def convert_scaled_to_non_scaled(model: nn.Module, inplace: bool = False): elif isinstance(m, BasicNorm): d[name] = convert_basic_norm(m) elif isinstance(m, ScaledLSTM): - d[name] = scaled_lstm_to_lstm(m) + if is_onnx: + d[name] = LSTMP(scaled_lstm_to_lstm(m)) + # See + # https://github.com/pytorch/pytorch/issues/47887 + # d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m))) + else: + d[name] = scaled_lstm_to_lstm(m) elif isinstance(m, ActivationBalancer): d[name] = nn.Identity() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py new file mode 120000 index 0000000000..b82e115fc8 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file
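As background for the is_onnx branch above: torch.onnx.export() cannot handle nn.LSTM with proj_size, which is why ScaledLSTM is first converted to a plain nn.LSTM and then wrapped in LSTMP, whose forward uses only linear, sigmoid, and tanh operations. A minimal sketch of the difference (not part of the patch; the sizes are arbitrary, and it assumes it is run from egs/librispeech/ASR/lstm_transducer_stateless2 so that lstmp.py is importable):

#!/usr/bin/env python3
# Sketch: ONNX export of a projected LSTM, directly vs. via the LSTMP wrapper.
# Sizes are arbitrary and only for illustration.
import torch
import torch.nn as nn
from lstmp import LSTMP

lstm = nn.LSTM(input_size=80, hidden_size=64, num_layers=1, proj_size=32)
x = torch.rand(9, 1, 80)    # (T, N, input_size); batch_first=False
h0 = torch.zeros(1, 1, 32)  # (1, N, proj_size)
c0 = torch.zeros(1, 1, 64)  # (1, N, hidden_size)

# Exporting the native LSTM with proj_size is expected to fail.
try:
    torch.onnx.export(lstm, (x, (h0, c0)), "lstm.onnx", opset_version=11)
except Exception as e:
    print("native nn.LSTM with proj_size:", e)

# LSTMP re-implements the projected LSTM with basic matrix operations,
# so torch.onnx.export() can trace and export it.
lstmp = LSTMP(lstm)
torch.onnx.export(
    lstmp,
    (x, (h0, c0)),
    "lstmp.onnx",
    opset_version=11,
    input_names=["x", "h0", "c0"],
    output_names=["y", "next_h0", "next_c0"],
)
print("saved lstmp.onnx")

test_lstmp.py above checks that LSTMP matches nn.LSTM numerically, both eagerly and under torch.jit.script()/torch.jit.trace(), which is what makes this substitution safe for export.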