diff --git a/.github/workflows/run-streaming-test.yaml b/.github/workflows/run-streaming-test.yaml new file mode 100644 index 000000000..7f88faf15 --- /dev/null +++ b/.github/workflows/run-streaming-test.yaml @@ -0,0 +1,118 @@ +# Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +name: Run streaming ASR tests + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + run_streaming_asr_tests: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-18.04, macos-10.15] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + python-version: [3.7, 3.8, 3.9] + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install GCC 7 + if: startsWith(matrix.os, 'ubuntu') + run: | + sudo apt-get install -y gcc-7 g++-7 + echo "CC=/usr/bin/gcc-7" >> $GITHUB_ENV + echo "CXX=/usr/bin/g++-7" >> $GITHUB_ENV + + - name: Install PyTorch ${{ matrix.torch }} + shell: bash + if: startsWith(matrix.os, 'ubuntu') + run: | + python3 -m pip install -qq --upgrade pip + python3 -m pip install -qq wheel twine typing_extensions websockets sentencepiece>=0.1.96 + python3 -m pip install -qq torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu numpy -f https://download.pytorch.org/whl/cpu/torch_stable.html + + - name: Install PyTorch ${{ matrix.torch }} + shell: bash + if: startsWith(matrix.os, 'macos') + run: | + python3 -m pip install -qq --upgrade pip + python3 -m pip install -qq wheel twine typing_extensions websockets sentencepiece>=0.1.96 + python3 -m pip install -qq torch==${{ matrix.torch }} torchaudio==${{ matrix.torchaudio }} numpy -f https://download.pytorch.org/whl/cpu/torch_stable.html + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-${{ matrix.os }} + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Install sherpa + shell: bash + run: | + python3 setup.py install + + - name: Download pretrained model and test-data + shell: bash + run: | + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01 + + - name: Start server + shell: bash + run: | + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + ./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py \ + --port 6006 \ + --max-batch-size 50 \ + --max-wait-ms 5 \ + --nn-pool-size 1 \ + --nn-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/exp/cpu_jit-epoch-39-avg-6-use-averaged-model-1.pt \ + --bpe-model-filename 
./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/bpe.model & + + echo "Sleep 10 seconds to wait for the server startup" + sleep 10 + + - name: Start client + shell: bash + run: | + ./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py \ + --server-addr localhost \ + --server-port 6006 \ + ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/test_wavs/1221-135766-0001.wav diff --git a/.github/workflows/run-test.yaml b/.github/workflows/run-test.yaml index 497b2b5c0..08afcb587 100644 --- a/.github/workflows/run-test.yaml +++ b/.github/workflows/run-test.yaml @@ -1,4 +1,3 @@ - # Copyright 2022 Xiaomi Corp. (author: Fangjun Kuang) # See ../../LICENSE for clarification regarding multiple authors diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml new file mode 100644 index 000000000..85c76b044 --- /dev/null +++ b/.github/workflows/style_check.yml @@ -0,0 +1,48 @@ +# Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: style_check + +on: + push: + branches: + - master + pull_request: + branches: + - master + +jobs: + style_check: + runs-on: ubuntu-18.04 + strategy: + matrix: + python-version: [3.8] + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python-version }} + + - name: Check style with cpplint + shell: bash + working-directory: ${{github.workspace}} + run: ./scripts/check_style_cpplint.sh diff --git a/README.md b/README.md index f0917e223..271d8c694 100644 --- a/README.md +++ b/README.md @@ -1,24 +1,25 @@ ## Introduction -An ASR server framework in **Python**, aiming to support both streaming +An ASR server framework in **Python**, supporting both streaming and non-streaming recognition. -**Note**: Only non-streaming recognition is implemented at present. We -will add streaming recognition later. - CPU-bound tasks, such as neural network computation, are implemented in C++; while IO-bound tasks, such as socket communication, are implemented in Python. -**Caution**: We assume the model is trained using pruned stateless RNN-T -from [icefall][icefall] and it is from a directory like -`pruned_transducer_statelessX` where `X` >=2. +**Caution**: For offline ASR, we assume the model is trained using pruned +stateless RNN-T from [icefall][icefall] and it is from a directory like +`pruned_transducer_statelessX` where `X` >=2. For streaming ASR, we +assume the model is using `pruned_stateless_emformer_rnnt2`. -We provide a Colab notebook, containing how to start the server, how to -start the client, and how to decode `test-clean` of LibriSpeech. +For the offline ASR, we provide a Colab notebook, containing how to start the +server, how to start the client, and how to decode `test-clean` of LibriSpeech. 
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1JX5Ph2onYm1ZjNP_94eGqZ-DIRMLlIca?usp=sharing) +For the streaming ASR, we provide a YouTube demo, showing you how to use it. +See + ## Installation First, you have to install `PyTorch` and `torchaudio`. PyTorch 1.10 is known @@ -63,7 +64,6 @@ make -j export PYTHONPATH=$PWD/../sherpa/python:$PWD/lib:$PYTHONPATH ``` - ## Usage First, check that `sherpa` has been installed successfully: @@ -74,7 +74,103 @@ python3 -c "import sherpa; print(sherpa.__version__)" It should print the version of `sherpa`. -### Start the server +#### Streaming ASR with pruned stateless Emformer RNN-T + +#### Start the server + +To start the server, you need to first generate two files: + +- (1) The torch script model file. You can use `export.py --jit=1` in +`pruned_stateless_emformer_rnnt2` from [icefall][icefall]. + +- (2) The BPE model file. You can find it in `data/lang_bpe_XXX/bpe.model` +in [icefall][icefall], where `XXX` is the number of BPE tokens used in +the training. + +With the above two files ready, you can start the server with the +following command: + +```bash +./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py \ + --port 6006 \ + --max-batch-size 50 \ + --max-wait-ms 5 \ + --nn-pool-size 1 \ + --nn-model-filename ./path/to/exp/cpu_jit.pt \ + --bpe-model-filename ./path/to/data/lang_bpe_500/bpe.model +``` + +You can use `./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py --help` +to view the help message. + +We provide a pretrained model using the LibriSpeech dataset at + + +The following shows how to use the above pretrained model to start the server. + +```bash +git lfs install +git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01 + +./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py \ + --port 6006 \ + --max-batch-size 50 \ + --max-wait-ms 5 \ + --nn-pool-size 1 \ + --nn-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/exp/cpu_jit-epoch-39-avg-6-use-averaged-model-1.pt \ + --bpe-model-filename ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/bpe.model +``` + +#### Start the client + +We provide two clients at present: + + - (1) [./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py](./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py) + It shows how to decode a single sound file. + + - (2) [./sherpa/bin/pruned_stateless_emformer_rnnt2/web](./sherpa/bin/pruned_stateless_emformer_rnnt2/web) + You can record your speech in real-time within a browser and send it to the server for recognition. + +##### streaming_client.py + +```bash +./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py --help + +./sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py \ + --server-addr localhost \ + --server-port 6006 \ + ./icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/test_wavs/1221-135766-0001.wav +``` + +##### Web client + +```bash +cd ./sherpa/bin/pruned_stateless_emformer_rnnt2/web +python3 -m http.server 6008 +``` + +Then open your browser and go to `http://localhost:6008/record.html`. You will +see a UI like the following screenshot. + +![web client screenshot](./pic/emformer-streaming-asr-web-client.png) + +Click the button `Record`. + +Now you can `speak` and you will get recognition results from the +server in real-time. 
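+
+The web page communicates with the streaming server over a WebSocket
+opened in [./sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js](./sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js).
+The connection line, quoted below for reference, looks like this:
+
+```js
+// record.js: connect to the streaming ASR server
+socket = new WebSocket('ws://localhost:6006/');
+```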
+ +**Caution**: For the web client, we hard-code the server port to `6006`. +You can change the file [./sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js](./sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js) +to replace `6006` in it to whatever port the server is using. + +**Caution**: `http://0.0.0.0:6008/record.html` or `http://127.0.0.1:6008/record.html` +won't work. You have to use `localhost`. Otherwise, you won't be able to use +your microphone in your browser since we are not using `https` which requires +a certificate. + +### Offline ASR + +#### Start the server To start the server, you need to first generate two files: @@ -97,7 +193,7 @@ sherpa/bin/offline_server.py \ --feature-extractor-pool-size 5 \ --nn-pool-size 1 \ --nn-model-filename ./path/to/exp/cpu_jit.pt \ - --bpe-model-filename ./path/to/data/lang_bpe_500/bpe.model & + --bpe-model-filename ./path/to/data/lang_bpe_500/bpe.model ``` You can use `./sherpa/bin/offline_server.py --help` to view the help message. @@ -122,7 +218,7 @@ sherpa/bin/offline_server.py \ --bpe-model-filename ./icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/data/lang_bpe_500/bpe.model ``` -### Start the client +#### Start the client After starting the server, you can use the following command to start the client: ```bash @@ -147,7 +243,7 @@ sherpa/bin/offline_client.py \ icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13//test_wavs/1221-135766-0002.wav ``` -### RTF test +#### RTF test We provide a demo [./sherpa/bin/decode_manifest.py](./sherpa/bin/decode_manifest.py) to decode the `test-clean` dataset from the LibriSpeech corpus. diff --git a/pic/emformer-streaming-asr-web-client.png b/pic/emformer-streaming-asr-web-client.png new file mode 100644 index 000000000..f0bf4b3f6 Binary files /dev/null and b/pic/emformer-streaming-asr-web-client.png differ diff --git a/scripts/check_style_cpplint.sh b/scripts/check_style_cpplint.sh new file mode 100755 index 000000000..2c72c2760 --- /dev/null +++ b/scripts/check_style_cpplint.sh @@ -0,0 +1,126 @@ +#!/bin/bash +# +# Copyright 2020 Mobvoi Inc. (authors: Fangjun Kuang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Usage: +# +# (1) To check files of the last commit +# ./scripts/check_style_cpplint.sh +# +# (2) To check changed files not committed yet +# ./scripts/check_style_cpplint.sh 1 +# +# (3) To check all files in the project +# ./scripts/check_style_cpplint.sh 2 + + +cpplint_version="1.5.4" +cur_dir=$(cd $(dirname $BASH_SOURCE) && pwd) +sherpa_dir=$(cd $cur_dir/.. && pwd) + +build_dir=$sherpa_dir/build +mkdir -p $build_dir + +cpplint_src=$build_dir/cpplint-${cpplint_version}/cpplint.py + +if [ ! 
-d "$build_dir/cpplint-${cpplint_version}" ]; then + pushd $build_dir + if command -v wget &> /dev/null; then + wget https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz + elif command -v curl &> /dev/null; then + curl -O -SL https://github.com/cpplint/cpplint/archive/${cpplint_version}.tar.gz + else + echo "Please install wget or curl to download cpplint" + exit 1 + fi + tar xf ${cpplint_version}.tar.gz + rm ${cpplint_version}.tar.gz + + # cpplint will report the following error for: __host__ __device__ ( + # + # Extra space before ( in function call [whitespace/parens] [4] + # + # the following patch disables the above error + sed -i "3490i\ not Search(r'__host__ __device__\\\s+\\\(', fncall) and" $cpplint_src + popd +fi + +source $sherpa_dir/scripts/utils.sh + +# return true if the given file is a c++ source file +# return false otherwise +function is_source_code_file() { + case "$1" in + *.cc|*.h|*.cu) + echo true;; + *) + echo false;; + esac +} + +function check_style() { + python3 $cpplint_src $1 || abort $1 +} + +function check_last_commit() { + files=$(git diff HEAD^1 --name-only --diff-filter=ACDMRUXB) + echo $files +} + +function check_current_dir() { + files=$(git status -s -uno --porcelain | awk '{ + if (NF == 4) { + # a file has been renamed + print $NF + } else { + print $2 + }}') + + echo $files +} + +function do_check() { + case "$1" in + 1) + echo "Check changed files" + files=$(check_current_dir) + ;; + 2) + echo "Check all files" + files=$(find $sherpa_dir/sherpa -name "*.h" -o -name "*.cc" -o -name "*.cu") + ;; + *) + echo "Check last commit" + files=$(check_last_commit) + ;; + esac + + for f in $files; do + need_check=$(is_source_code_file $f) + if $need_check; then + [[ -f $f ]] && check_style $f + fi + done +} + +function main() { + do_check $1 + + ok "Great! Style check passed!" +} + +cd $sherpa_dir + +main $1 diff --git a/scripts/utils.sh b/scripts/utils.sh new file mode 100644 index 000000000..fb424a7b8 --- /dev/null +++ b/scripts/utils.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +default='\033[0m' +bold='\033[1m' +red='\033[31m' +green='\033[32m' + +function ok() { + printf "${bold}${green}[OK]${default} $1\n" +} + +function error() { + printf "${bold}${red}[FAILED]${default} $1\n" +} + +function abort() { + printf "${bold}${red}[FAILED]${default} $1\n" + exit 1 +} diff --git a/sherpa/bin/README.md b/sherpa/bin/README.md new file mode 100644 index 000000000..e7ad66b7b --- /dev/null +++ b/sherpa/bin/README.md @@ -0,0 +1,35 @@ +# File descriptions + +## pruned_transducer_statelessX + +Files in the part assume the model is from `pruned_transducer_statelessX` in +the folder +where `X>=2`. + +| Filename | Description | +|----------|-------------| +| [offline_server.py](./offline_server.py) | The server for offline ASR | +| [offline_client.py](./offline_client.py) | The client for offline ASR | +| [decode_manifest.py](./decode_manifest.py) | Demo for computing RTF and WER| + +If you want to test the offline server without training your own model, you +can download pretrained models on the LibriSpeech corpus by visiting +. +There you can find links to various pretrained models. + +For instance, you can use + +## pruned_stateless_emformer_rnnt2 + +Files in the part assume the model is from `pruned_stateless_emformer_rnnt2` in +the folder . 
+ +| Filename | Description | +|----------|-------------| +| [pruned_stateless_emformer_rnnt2/streaming_server.py](./pruned_stateless_emformer_rnnt2/streaming_server.py) | The server for streaming ASR | +| [pruned_stateless_emformer_rnnt2/streaming_client.py](./pruned_stateless_emformer_rnnt2/streaming_client.py) | The client for offline ASR | +| [pruned_stateless_emformer_rnnt2/decode.py](./pruned_stateless_emformer_rnnt2/decode.py) | Utilities for streaming ASR| + +You can use the pretrained model from + +to test it. diff --git a/sherpa/bin/decode_manifest.py b/sherpa/bin/decode_manifest.py index c06b94511..5aa668fbd 100755 --- a/sherpa/bin/decode_manifest.py +++ b/sherpa/bin/decode_manifest.py @@ -98,7 +98,7 @@ async def send( samples = c.load_audio().reshape(-1).astype(np.float32) num_bytes = samples.nbytes - await websocket.send((num_bytes).to_bytes(8, "big", signed=True)) + await websocket.send((num_bytes).to_bytes(8, "little", signed=True)) frame_size = (2 ** 20) // 4 # max payload is 1MB start = 0 diff --git a/sherpa/bin/offline_client.py b/sherpa/bin/offline_client.py index 3c4ba2749..37697c5f4 100755 --- a/sherpa/bin/offline_client.py +++ b/sherpa/bin/offline_client.py @@ -81,7 +81,7 @@ async def main(): wave = wave.squeeze(0) num_bytes = wave.numel() * wave.element_size() - await websocket.send((num_bytes).to_bytes(8, "big", signed=True)) + await websocket.send((num_bytes).to_bytes(8, "little", signed=True)) frame_size = (2 ** 20) // 4 # max payload is 1MB start = 0 diff --git a/sherpa/bin/offline_server.py b/sherpa/bin/offline_server.py index ee940bd84..56e316ada 100755 --- a/sherpa/bin/offline_server.py +++ b/sherpa/bin/offline_server.py @@ -21,6 +21,8 @@ the same time. Usage: + ./offline_server.py --help + ./offline_server.py """ @@ -78,7 +80,7 @@ def get_args(): default=25, help="""Max batch size for computation. Note if there are not enough requests in the queue, it will wait for max_wait_ms time. After that, - even if there are still not enough requests, it still sends the + even if there are not enough requests, it still sends the available requests in the queue for computation. """, ) @@ -137,7 +139,7 @@ def get_args(): def run_model_and_do_greedy_search( - model: torch.jit.ScriptModule, + model: RnntModel, features: List[torch.Tensor], ) -> List[List[int]]: """Run RNN-T model with the given features and use greedy search @@ -287,7 +289,7 @@ def _build_nn_model( return ans - async def loop(self, port: int): + async def run(self, port: int): logging.info("started") task = asyncio.create_task(self.feature_consumer_task()) @@ -309,7 +311,7 @@ async def recv_audio_samples( The message from the client has the following format: - a header of 8 bytes, containing the number of bytes of the tensor. - The header is in big endian format. + The header is in little endian format. - a binary representation of the 1-D torch.float32 tensor. 
Args: @@ -323,7 +325,7 @@ async def recv_audio_samples( async for message in socket: if expected_num_bytes is None: assert len(message) >= 8, (len(message), message) - expected_num_bytes = int.from_bytes(message[:8], "big", signed=True) + expected_num_bytes = int.from_bytes(message[:8], "little", signed=True) received += message[8:] if len(received) == expected_num_bytes: break @@ -459,7 +461,7 @@ def main(): feature_extractor_pool_size=feature_extractor_pool_size, nn_pool_size=nn_pool_size, ) - asyncio.run(offline_server.loop(port)) + asyncio.run(offline_server.run(port)) torch.set_num_threads(1) diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py new file mode 100644 index 000000000..e5c94d30d --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/decode.py @@ -0,0 +1,209 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List + +import torch +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def unstack_states( + states: List[List[torch.Tensor]], +) -> List[List[List[torch.Tensor]]]: + """Unstack the Emformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state for the i-th + utterance in the batch. + + Args: + states: + A list-of-list of tensors. ``len(states)`` equals to number of + layers in the Emformer. ``states[i]`` contains the states for + the i-th layer. ``states[i][k]`` is either a 3-D tensor of shape + ``(T, N, C)`` or a 2-D tensor of shape ``(C, N)`` + Returns: + Return the states for each utterance. ans[i] is the state for the i-th + utterance. Note that the returned state does not contain the batch + dimension. + """ + batch_size = states[0][0].size(1) + num_layers = len(states) + + ans = [None] * batch_size + for i in range(batch_size): + ans[i] = [[] for _ in range(num_layers)] + + for li, layer in enumerate(states): + for s in layer: + s_list = s.unbind(dim=1) + # We will use stack(dim=1) later in stack_states() + for bi, b in enumerate(ans): + b[li].append(s_list[bi]) + return ans + + +def stack_states( + state_list: List[List[List[torch.Tensor]]], +) -> List[List[torch.Tensor]]: + """Stack list of Emformer states that correspond to separate utterances + into a single Emformer state so that it can be used as an input for + Emformer when those utterances are formed into a batch. + + Note: + It is the inverse of :func:`unstack_states`. + + Args: + state_list: + Each element in state_list corresponds to the internal state + of the Emformer model for a single utterance. + Returns: + Return a new state corresponding to a batch of utterances. + See the input argument of :func:`unstack_states` for the meaning + of the returned tensor. 
+ """ + batch_size = len(state_list) + ans = [] + for layer in state_list[0]: + # layer is a list of tensors + if batch_size > 1: + ans.append([[s] for s in layer]) + # Note: We will stack ans[layer][s][] later to get ans[layer][s] + else: + ans.append([s.unsqueeze(1) for s in layer]) + + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states): + for si, s in enumerate(layer): + ans[li][si].append(s) + if b == batch_size - 1: + ans[li][si] = torch.stack(ans[li][si], dim=1) + # We will use unbind(dim=1) later in unstack_states() + return ans + + +def _create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +class Stream(object): + def __init__( + self, + context_size: int, + blank_id: int, + initial_states: List[List[torch.Tensor]], + decoder_out: torch.Tensor, + ) -> None: + """ + Args: + context_size: + Context size of the RNN-T decoder model. + blank_id: + Blank token ID of the BPE model. + initial_states: + The initial states of the Emformer model. Note that the state + does not contain the batch dimension. + decoder_out: + The initial decoder out corresponding to the decoder input + `[blank_id]*context_size` + """ + self.feature_extractor = _create_streaming_feature_extractor() + # It contains a list of 2-D tensors representing the feature frames. + # Each entry is of shape (1, feature_dim) + self.features: List[torch.Tensor] = [] + self.num_fetched_frames = 0 + + self.states = initial_states + self.decoder_out = decoder_out + + self.context_size = context_size + self.hyp = [blank_id] * context_size + self.log_eps = math.log(1e-10) + + def accept_waveform( + self, + sampling_rate: float, + waveform: torch.Tensor, + ) -> None: + """Feed audio samples to the feature extractor and compute features + if there are enough samples available. + + Caution: + The range of the audio samples should match the one used in the + training. That is, if you use the range [-1, 1] in the training, then + the input audio samples should also be normalized to [-1, 1]. + + Args + sampling_rate: + The sampling rate of the input audio samples. It is used for sanity + check to ensure that the input sampling rate equals to the one + used in the extractor. If they are not equal, then no resampling + will be performed; instead an error will be thrown. + waveform: + A 1-D torch tensor of dtype torch.float32 containing audio samples. + It should be on CPU. + """ + self.feature_extractor.accept_waveform( + sampling_rate=sampling_rate, + waveform=waveform, + ) + self._fetch_frames() + + def input_finished(self) -> None: + """Signal that no more audio samples available and the feature + extractor should flush the buffered samples to compute frames. 
+ """ + self.feature_extractor.input_finished() + self._fetch_frames() + + def _fetch_frames(self) -> None: + """Fetch frames from the feature extractor""" + while self.num_fetched_frames < self.feature_extractor.num_frames_ready: + frame = self.feature_extractor.get_frame(self.num_fetched_frames) + self.features.append(frame) + self.num_fetched_frames += 1 + + def add_tail_paddings(self, n: int = 20) -> None: + """Add some tail paddings so that we have enough context to process + frames at the very end of an utterance. + + Args: + n: + Number of tail padding frames to be added. You can increase it if + it happens that there are many missing tokens for the last word of + an utterance. + """ + tail_padding = torch.full( + (1, self.feature_extractor.opts.mel_opts.num_bins), + fill_value=self.log_eps, + dtype=torch.float32, + ) + + self.features += [tail_padding] * n diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py new file mode 100755 index 000000000..695a7a629 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_client.py @@ -0,0 +1,121 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A client for streaming ASR recognition. + +Usage: + ./streaming_client.py \ + --server-addr localhost \ + --server-port 6006 \ + /path/to/foo.wav + +(Note: You have to first start the server before starting the client) +""" +import argparse +import asyncio +import logging + +import torchaudio +import websockets + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--server-addr", + type=str, + default="localhost", + help="Address of the server", + ) + + parser.add_argument( + "--server-port", + type=int, + default=6006, + help="Port of the server", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + return parser.parse_args() + + +async def receive_results(socket: websockets.WebSocketServerProtocol): + partial_result = "" + async for message in socket: + if message == "Done": + break + partial_result = message + logging.info(f"Partial result: {partial_result}") + + return partial_result + + +async def main(): + args = get_args() + + server_addr = args.server_addr + server_port = args.server_port + test_wav = args.sound_file + + async with websockets.connect(f"ws://{server_addr}:{server_port}") as websocket: + logging.info(f"Sending {test_wav}") + wave, sample_rate = torchaudio.load(test_wav) + assert sample_rate == 16000, sample_rate + + receive_task = asyncio.create_task(receive_results(websocket)) + + wave = wave.squeeze(0) + + chunk_size = 4096 + start = 0 + while start < wave.numel(): + end = start + chunk_size + d = wave.numpy().data[start:end] + + num_bytes = d.nbytes + await websocket.send((num_bytes).to_bytes(8, "little", signed=True)) + + await websocket.send(d) + + start = end + + s = b"Done" + await websocket.send((len(s)).to_bytes(8, "little", signed=True)) + await websocket.send(s) + + logging.info("Send done") + + decoding_results = await receive_task + logging.info(f"{test_wav}\n{decoding_results}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + asyncio.run(main()) diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py new file mode 100755 index 000000000..be77cf268 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/streaming_server.py @@ -0,0 +1,477 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +A server for streaming ASR recognition. By streaming it means the audio samples +are coming in real-time. You don't need to wait until all audio samples are +captured before sending them for recognition. + +It supports multiple clients sending at the same time. 
+ +Usage: + ./streaming_server.py --help + + ./streaming_server.py +""" + +import argparse +import asyncio +import logging +import math +import warnings +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Tuple + +import sentencepiece as spm +import torch +import websockets +from sherpa import RnntEmformerModel, streaming_greedy_search + +from decode import Stream, stack_states, unstack_states + +DEFAULT_NN_MODEL_FILENAME = "/ceph-fj/fangjun/open-source-2/icefall-streaming-2/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/exp-full/cpu_jit-epoch-39-avg-6-use-averaged-model-1.pt" # noqa +DEFAULT_BPE_MODEL_FILENAME = "/ceph-fj/fangjun/open-source-2/icefall-streaming-2/egs/librispeech/ASR/data/lang_bpe_500/bpe.model" # noqa + +TEST_WAV = "/ceph-fj/fangjun/open-source-2/icefall-models/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1089-134686-0001.wav" +TEST_WAV = "/ceph-fj/fangjun/open-source-2/icefall-models/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0001.wav" +# TEST_WAV = "/ceph-fj/fangjun/open-source-2/icefall-models/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13/test_wavs/1221-135766-0002.wav" + + +def get_args(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--port", + type=int, + default=6006, + help="The server will listen on this port", + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + default=DEFAULT_NN_MODEL_FILENAME, + help="""The torchscript model. You can use + icefall/egs/librispeech/ASR/pruned_transducer_statelessX/export.py --jit=1 + to generate this model. + """, + ) + + parser.add_argument( + "--bpe-model-filename", + type=str, + default=DEFAULT_BPE_MODEL_FILENAME, + help="""The BPE model + You can find it in the directory egs/librispeech/ASR/data/lang_bpe_xxx + where xxx is the number of BPE tokens you used to train the model. + """, + ) + + parser.add_argument( + "--nn-pool-size", + type=int, + default=1, + help="Number of threads for NN computation and decoding.", + ) + + parser.add_argument( + "--max-batch-size", + type=int, + default=50, + help="""Max batch size for computation. Note if there are not enough + requests in the queue, it will wait for max_wait_ms time. After that, + even if there are not enough requests, it still sends the + available requests in the queue for computation. + """, + ) + + parser.add_argument( + "--max-wait-ms", + type=float, + default=10, + help="""Max time in millisecond to wait to build batches for inference. + If there are not enough requests in the stream queue to build a batch + of max_batch_size, it waits up to this time before fetching available + requests for computation. + """, + ) + + return parser.parse_args() + + +def run_model_and_do_greedy_search( + server: "StreamingServer", + stream_list: List[Stream], +) -> None: + """Run the model on the given stream list and do greedy search. + Args: + server: + An instance of `StreamingServer`. + stream_list: + A list of streams to be processed. It is changed in-place. + That is, the attribute `states`, `decoder_out`, and `hyp` are + updated in-place. 
+ """ + model = server.model + device = model.device + segment_length = server.segment_length + chunk_length = server.chunk_length + + batch_size = len(stream_list) + + state_list = [] + decoder_out_list = [] + hyp_list = [] + feature_list = [] + for s in stream_list: + state_list.append(s.states) + decoder_out_list.append(s.decoder_out) + hyp_list.append(s.hyp) + + f = s.features[:chunk_length] + s.features = s.features[segment_length:] + + b = torch.cat(f, dim=0) + feature_list.append(b) + + features = torch.stack(feature_list, dim=0).to(device) + states = stack_states(state_list) + decoder_out = torch.cat(decoder_out_list, dim=0) + + features_length = torch.full( + (batch_size,), + fill_value=features.size(1), + device=device, + ) + + (encoder_out, next_states) = model.encoder_streaming_forward( + features, + features_length, + states, + ) + + # Note: It does not return the next_encoder_out_len since + # there are no paddings for streaming ASR. Each stream + # has the same input number of frames, i.e., server.chunk_length. + next_decoder_out, next_hyp_list = streaming_greedy_search( + model=model, + encoder_out=encoder_out, + decoder_out=decoder_out, + hyps=hyp_list, + ) + + next_state_list = unstack_states(next_states) + next_decoder_out_list = next_decoder_out.split(1) + for i, s in enumerate(stream_list): + s.states = next_state_list[i] + s.decoder_out = next_decoder_out_list[i] + s.hyp = next_hyp_list[i] + + +class StreamingServer(object): + def __init__( + self, + nn_model_filename: str, + bpe_model_filename: str, + nn_pool_size: int, + max_wait_ms: float, + max_batch_size: int, + ): + """ + Args: + nn_model_filename: + Path to the torchscript model + bpe_model_filename: + Path to the BPE model + nn_pool_size: + Number of threads for the thread pool that is responsible for + neural network computation and decoding. + max_wait_ms: + Max wait time in milliseconds in order to build a batch of + `batch_size`. + max_batch_size: + Max batch size for inference. + """ + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + else: + device = torch.device("cpu") + + self.model = RnntEmformerModel(nn_model_filename, device=device) + + # number of frames before subsampling + self.segment_length = self.model.segment_length + + self.right_context_length = self.model.right_context_length + + # We add 3 here since the subsampling method is using + # ((len - 1) // 2 - 1) // 2) + self.chunk_length = (self.segment_length + 3) + self.right_context_length + + self.sp = spm.SentencePieceProcessor() + self.sp.load(bpe_model_filename) + + self.context_size = self.model.context_size + self.blank_id = self.model.blank_id + self.log_eps = math.log(1e-10) + + initial_states = self.model.get_encoder_init_states() + self.initial_states = unstack_states(initial_states)[0] + decoder_input = torch.tensor( + [[self.blank_id] * self.context_size], + device=device, + dtype=torch.int64, + ) + self.initial_decoder_out = self.model.decoder_forward(decoder_input).squeeze(1) + + self.nn_pool = ThreadPoolExecutor( + max_workers=nn_pool_size, + thread_name_prefix="nn", + ) + + self.stream_queue = asyncio.Queue() + self.max_wait_ms = max_wait_ms + self.max_batch_size = max_batch_size + + async def stream_consumer_task(self): + """The function extract streams from the queue, batches them up, sends + them to the RNN-T model for computation and decoding. 
+ """ + while True: + if self.stream_queue.empty(): + await asyncio.sleep(self.max_wait_ms / 1000) + continue + + batch = [] + try: + while len(batch) < self.max_batch_size: + item = self.stream_queue.get_nowait() + + assert len(item[0].features) >= self.chunk_length, len( + item[0].features + ) + + batch.append(item) + except asyncio.QueueEmpty: + pass + stream_list = [b[0] for b in batch] + future_list = [b[1] for b in batch] + + loop = asyncio.get_running_loop() + await loop.run_in_executor( + self.nn_pool, + run_model_and_do_greedy_search, + self, + stream_list, + ) + + for f in future_list: + self.stream_queue.task_done() + f.set_result(None) + + async def compute_and_decode( + self, + stream: Stream, + ) -> None: + """Put the stream into the queue and wait it to be processed by the + consumer task. + + Args: + stream: + The stream to be processed. Note: It is changed in-place. + """ + loop = asyncio.get_running_loop() + future = loop.create_future() + await self.stream_queue.put((stream, future)) + await future + + async def run(self, port: int): + task = asyncio.create_task(self.stream_consumer_task()) + + async with websockets.serve(self.handle_connection, "", port): + await asyncio.Future() # run forever + + async def handle_connection( + self, + socket: websockets.WebSocketServerProtocol, + ): + """Receive audio samples from the client, process it, and sends + deocoding result back to the client. + + Args: + socket: + The socket for communicating with the client. + """ + logging.info(f"Connected: {socket.remote_address}") + stream = Stream( + context_size=self.context_size, + blank_id=self.blank_id, + initial_states=self.initial_states, + decoder_out=self.initial_decoder_out, + ) + + last = b"" + while True: + samples, last = await self.recv_audio_samples(socket, last) + if samples is None: + break + + # TODO(fangjun): At present, we assume the sampling rate + # of the received audio samples is always 16000. + stream.accept_waveform(sampling_rate=16000, waveform=samples) + + while len(stream.features) > self.chunk_length: + await self.compute_and_decode(stream) + await socket.send(f"{self.sp.decode(stream.hyp[self.context_size:])}") + + stream.input_finished() + while len(stream.features) > self.chunk_length: + await self.compute_and_decode(stream) + + if len(stream.features) > 0: + n = self.chunk_length - len(stream.features) + stream.add_tail_paddings(n) + await self.compute_and_decode(stream) + stream.features = [] + + result = self.sp.decode(stream.hyp[self.context_size :]) + await socket.send(result) + await socket.send("Done") + + logging.info(f"Disconnected: {socket.remote_address}") + + async def recv_audio_samples( + self, + socket: websockets.WebSocketServerProtocol, + last: Optional[bytes] = None, + ) -> Tuple[Optional[torch.Tensor], Optional[bytes]]: + """Receives a tensor from the client. + + The message from the client contains two parts: header and payload + + - the header contains 8 bytes in little endian format, specifying + the number of bytes in the payload. + + - the payload contains either a binary representation of the 1-D + torch.float32 tensor or the bytes object b"Done" which means + the end of utterance. + + Args: + socket: + The socket for communicating with the client. + last: + Previous received content. + Returns: + Return a tuple containing: + - A 1-D torch.float32 tensor containing the audio samples + - Data for the next chunk, if any + or return a tuple (None, None) meaning the end of utterance. 
+ """ + header_len = 8 + + if last is None: + last = b"" + + async def receive_header(): + buf = last + async for message in socket: + buf += message + if len(buf) >= header_len: + break + if buf: + header = buf[:header_len] + remaining = buf[header_len:] + else: + header = None + remaining = None + + return header, remaining + + header, received = await receive_header() + + if header is None: + return None, None + + expected_num_bytes = int.from_bytes(header, "little", signed=True) + + async for message in socket: + received += message + if len(received) >= expected_num_bytes: + break + + if not received or received == b"Done": + return None, None + + this_chunk = received[:expected_num_bytes] + next_chunk = received[expected_num_bytes:] + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # PyTorch warns that the underlying buffer is not writable. + # We ignore it here as we are not going to write it anyway. + return torch.frombuffer(this_chunk, dtype=torch.float32), next_chunk + + +@torch.no_grad() +def main(): + args = get_args() + + logging.info(vars(args)) + + port = args.port + nn_model_filename = args.nn_model_filename + bpe_model_filename = args.bpe_model_filename + nn_pool_size = args.nn_pool_size + max_batch_size = args.max_batch_size + max_wait_ms = args.max_wait_ms + + server = StreamingServer( + nn_model_filename=nn_model_filename, + bpe_model_filename=bpe_model_filename, + nn_pool_size=nn_pool_size, + max_batch_size=max_batch_size, + max_wait_ms=max_wait_ms, + ) + asyncio.run(server.run(port)) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# See https://github.com/pytorch/pytorch/issues/38342 +# and https://github.com/pytorch/pytorch/issues/33354 +# +# If we don't do this, the delay increases whenever there is +# a new request that changes the actual batch size. +# If you use `py-spy dump --pid --native`, you will +# see a lot of time is spent in re-compiling the torch script model. +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +""" +// Use the following in C++ +torch::jit::getExecutorMode() = false; +torch::jit::getProfilingMode() = false; +torch::jit::setGraphExecutorOptimize(false); +""" + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/web/index.html b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/index.html new file mode 100644 index 000000000..d0fec4fc1 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/index.html @@ -0,0 +1,62 @@ + + + + + + + + + + + + + + Next-gen Kaldi demo + + + + + + + +
+    [index.html body markup was lost during extraction. The page lists two
+     entries: "Upload" (Recognition from a selected file) and "Record"
+     (Recognition from real-time recordings).]
+ + Code is available at + https://github.com/k2-fsa/icefall/tree/streaming/egs/librispeech/ASR/transducer_emformer + + + + + + + + + + diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/web/nav-partial.html b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/nav-partial.html new file mode 100644 index 000000000..513c1511f --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/nav-partial.html @@ -0,0 +1,22 @@ + diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.html b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.html new file mode 100644 index 000000000..4a06e0ec9 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.html @@ -0,0 +1,71 @@ + + + + + + + + + + + + + + Next-gen Kaldi demo (Upload file for recognition) + + + + + + + +

+    [record.html body markup was lost during extraction. The page shows the
+     heading "Recognition from real-time recordings" plus the elements
+     referenced by record.js: Record, Stop, and Clear buttons, a waveform
+     canvas, a results area, and a sound-clips list.]
+ + + + + + + + + + + diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js new file mode 100644 index 000000000..123fffd89 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/record.js @@ -0,0 +1,343 @@ +// This file copies and modifies code +// from https://mdn.github.io/web-dictaphone/scripts/app.js +// and https://gist.github.com/meziantou/edb7217fddfbb70e899e + +var socket; +function initWebSocket() { + socket = new WebSocket('ws://localhost:6006/'); + + // Connection opened + socket.addEventListener('open', function(event) { + console.log('connected'); + document.getElementById('record').disabled = false; + }); + + // Connection closed + socket.addEventListener('close', function(event) { + console.log('disconnected'); + document.getElementById('record').disabled = true; + initWebSocket(); + }); + + // Listen for messages + socket.addEventListener('message', function(event) { + document.getElementById('results').innerHTML = event.data; + console.log('Received message: ', event.data); + }); +} + +const recordBtn = document.getElementById('record'); +const stopBtn = document.getElementById('stop'); +const clearBtn = document.getElementById('clear'); +const soundClips = document.getElementById('sound-clips'); +const canvas = document.getElementById('canvas'); +const mainSection = document.querySelector('.container'); + +stopBtn.disabled = true; + +let audioCtx; +const canvasCtx = canvas.getContext('2d'); +let mediaStream; +let analyser; + +let expectedSampleRate = 16000; +let recordSampleRate; // the sampleRate of the microphone +let recorder = null; // the microphone +let leftchannel = []; // TODO: Use a single channel + +let recordingLength = 0; // number of samples so far + +clearBtn.onclick = function() { + document.getElementById('results').innerHTML = ''; +}; + +// copied/modified from https://mdn.github.io/web-dictaphone/ +// and +// https://gist.github.com/meziantou/edb7217fddfbb70e899e +if (navigator.mediaDevices.getUserMedia) { + console.log('getUserMedia supported.'); + + // see https://w3c.github.io/mediacapture-main/#dom-mediadevices-getusermedia + const constraints = {audio: true}; + + let onSuccess = function(stream) { + if (!audioCtx) { + audioCtx = new AudioContext(); + } + console.log(audioCtx); + recordSampleRate = audioCtx.sampleRate; + console.log('sample rate ' + recordSampleRate); + + // creates an audio node from the microphone incoming stream + mediaStream = audioCtx.createMediaStreamSource(stream); + console.log(mediaStream); + + // https://developer.mozilla.org/en-US/docs/Web/API/AudioContext/createScriptProcessor + // bufferSize: the onaudioprocess event is called when the buffer is full + var bufferSize = 2048; + var numberOfInputChannels = 2; + var numberOfOutputChannels = 2; + if (audioCtx.createScriptProcessor) { + recorder = audioCtx.createScriptProcessor( + bufferSize, numberOfInputChannels, numberOfOutputChannels); + } else { + recorder = audioCtx.createJavaScriptNode( + bufferSize, numberOfInputChannels, numberOfOutputChannels); + } + console.log(recorder); + + recorder.onaudioprocess = function(e) { + let samples = new Float32Array(e.inputBuffer.getChannelData(0)) + samples = downsampleBuffer(samples, expectedSampleRate); + + let buf = new Int16Array(samples.length); + for (var i = 0; i < samples.length; ++i) { + let s = samples[i]; + if (s >= 1) + s = 1; + else if (s <= -1) + s = -1; + + samples[i] = s; + buf[i] = s * 32767; + } + + const header = new 
ArrayBuffer(8); + new DataView(header).setInt32( + 0, samples.byteLength, true /* littleEndian */); + + socket.send(new BigInt64Array(header, 0, 1)); + socket.send(samples); + + leftchannel.push(buf); + recordingLength += bufferSize; + }; + + visualize(stream); + mediaStream.connect(analyser); + + recordBtn.onclick = function() { + mediaStream.connect(recorder); + mediaStream.connect(analyser); + recorder.connect(audioCtx.destination); + + console.log('recorder started'); + recordBtn.style.background = 'red'; + + stopBtn.disabled = false; + recordBtn.disabled = true; + }; + + stopBtn.onclick = function() { + console.log('recorder stopped'); + socket.close(); + + // stopBtn recording + recorder.disconnect(audioCtx.destination); + mediaStream.disconnect(recorder); + mediaStream.disconnect(analyser); + + recordBtn.style.background = ''; + recordBtn.style.color = ''; + // mediaRecorder.requestData(); + + stopBtn.disabled = true; + recordBtn.disabled = false; + + const clipName = + prompt('Enter a name for your sound clip?', 'My unnamed clip'); + + const clipContainer = document.createElement('article'); + const clipLabel = document.createElement('p'); + const audio = document.createElement('audio'); + const deleteButton = document.createElement('button'); + clipContainer.classList.add('clip'); + audio.setAttribute('controls', ''); + deleteButton.textContent = 'Delete'; + deleteButton.className = 'delete'; + + if (clipName === null) { + clipLabel.textContent = 'My unnamed clip'; + } else { + clipLabel.textContent = clipName; + } + + clipContainer.appendChild(audio); + + clipContainer.appendChild(clipLabel); + clipContainer.appendChild(deleteButton); + soundClips.appendChild(clipContainer); + + audio.controls = true; + let samples = flatten(leftchannel); + const blob = toWav(samples); + + leftchannel = []; + const audioURL = window.URL.createObjectURL(blob); + audio.src = audioURL; + console.log('recorder stopped'); + + deleteButton.onclick = function(e) { + let evtTgt = e.target; + evtTgt.parentNode.parentNode.removeChild(evtTgt.parentNode); + }; + + clipLabel.onclick = function() { + const existingName = clipLabel.textContent; + const newClipName = prompt('Enter a new name for your sound clip?'); + if (newClipName === null) { + clipLabel.textContent = existingName; + } else { + clipLabel.textContent = newClipName; + } + }; + }; + }; + + let onError = function(err) { + console.log('The following error occured: ' + err); + }; + + navigator.mediaDevices.getUserMedia(constraints).then(onSuccess, onError); +} else { + console.log('getUserMedia not supported on your browser!'); + alert('getUserMedia not supported on your browser!'); +} + +function visualize(stream) { + if (!audioCtx) { + audioCtx = new AudioContext(); + } + + const source = audioCtx.createMediaStreamSource(stream); + + if (!analyser) { + analyser = audioCtx.createAnalyser(); + analyser.fftSize = 2048; + } + const bufferLength = analyser.frequencyBinCount; + const dataArray = new Uint8Array(bufferLength); + + // source.connect(analyser); + // analyser.connect(audioCtx.destination); + + draw() + + function draw() { + const WIDTH = canvas.width + const HEIGHT = canvas.height; + + requestAnimationFrame(draw); + + analyser.getByteTimeDomainData(dataArray); + + canvasCtx.fillStyle = 'rgb(200, 200, 200)'; + canvasCtx.fillRect(0, 0, WIDTH, HEIGHT); + + canvasCtx.lineWidth = 2; + canvasCtx.strokeStyle = 'rgb(0, 0, 0)'; + + canvasCtx.beginPath(); + + let sliceWidth = WIDTH * 1.0 / bufferLength; + let x = 0; + + for (let i = 0; i < 
bufferLength; i++) { + let v = dataArray[i] / 128.0; + let y = v * HEIGHT / 2; + + if (i === 0) { + canvasCtx.moveTo(x, y); + } else { + canvasCtx.lineTo(x, y); + } + + x += sliceWidth; + } + + canvasCtx.lineTo(canvas.width, canvas.height / 2); + canvasCtx.stroke(); + } +} + +window.onresize = function() { + canvas.width = mainSection.offsetWidth; +}; + +window.onresize(); + +// this function is copied/modified from +// https://gist.github.com/meziantou/edb7217fddfbb70e899e +function flatten(listOfSamples) { + let n = 0; + for (let i = 0; i < listOfSamples.length; ++i) { + n += listOfSamples[i].length; + } + let ans = new Int16Array(n); + + let offset = 0; + for (let i = 0; i < listOfSamples.length; ++i) { + ans.set(listOfSamples[i], offset); + offset += listOfSamples[i].length; + } + return ans; +} + +// this function is copied/modified from +// https://gist.github.com/meziantou/edb7217fddfbb70e899e +function toWav(samples) { + let buf = new ArrayBuffer(44 + samples.length * 2); + var view = new DataView(buf); + + // http://soundfile.sapp.org/doc/WaveFormat/ + // F F I R + view.setUint32(0, 0x46464952, true); // chunkID + view.setUint32(4, 36 + samples.length * 2, true); // chunkSize + // E V A W + view.setUint32(8, 0x45564157, true); // format + // + // t m f + view.setUint32(12, 0x20746d66, true); // subchunk1ID + view.setUint32(16, 16, true); // subchunk1Size, 16 for PCM + view.setUint32(20, 1, true); // audioFormat, 1 for PCM + view.setUint16(22, 1, true); // numChannels: 1 channel + view.setUint32(24, expectedSampleRate, true); // sampleRate + view.setUint32(28, expectedSampleRate * 2, true); // byteRate + view.setUint16(32, 2, true); // blockAlign + view.setUint16(34, 16, true); // bitsPerSample + view.setUint32(36, 0x61746164, true); // Subchunk2ID + view.setUint32(40, samples.length * 2, true); // subchunk2Size + + let offset = 44; + for (let i = 0; i < samples.length; ++i) { + view.setInt16(offset, samples[i], true); + offset += 2; + } + + return new Blob([view], {type: 'audio/wav'}); +} + +// this function is copied from +// https://github.com/awslabs/aws-lex-browser-audio-capture/blob/master/lib/worker.js#L46 +function downsampleBuffer(buffer, exportSampleRate) { + if (exportSampleRate === recordSampleRate) { + return buffer; + } + var sampleRateRatio = recordSampleRate / exportSampleRate; + var newLength = Math.round(buffer.length / sampleRateRatio); + var result = new Float32Array(newLength); + var offsetResult = 0; + var offsetBuffer = 0; + while (offsetResult < result.length) { + var nextOffsetBuffer = Math.round((offsetResult + 1) * sampleRateRatio); + var accum = 0, count = 0; + for (var i = offsetBuffer; i < nextOffsetBuffer && i < buffer.length; i++) { + accum += buffer[i]; + count++; + } + result[offsetResult] = accum / count; + offsetResult++; + offsetBuffer = nextOffsetBuffer; + } + return result; +}; diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/web/upload.html b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/upload.html new file mode 100644 index 000000000..afc1882a3 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/upload.html @@ -0,0 +1,58 @@ + + + + + + + + + + + + + + Next-gen Kaldi demo (Upload file for recognition) + + + + + + + +

+    [upload.html body markup was lost during extraction. The page shows the
+     heading "Recognition from a selected file" plus the elements referenced
+     by upload.js: a file chooser and a results area.]
+ + + + + + + + + + + + diff --git a/sherpa/bin/pruned_stateless_emformer_rnnt2/web/upload.js b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/upload.js new file mode 100644 index 000000000..a2b0f8644 --- /dev/null +++ b/sherpa/bin/pruned_stateless_emformer_rnnt2/web/upload.js @@ -0,0 +1,60 @@ +/** +References +https://developer.mozilla.org/en-US/docs/Web/API/FileList +https://developer.mozilla.org/en-US/docs/Web/API/FileReader +https://javascript.info/arraybuffer-binary-arrays +https://developer.mozilla.org/zh-CN/docs/Web/API/WebSocket +https://developer.mozilla.org/en-US/docs/Web/API/WebSocket/send +*/ + +var socket; +function initWebSocket() { + socket = new WebSocket("ws://localhost:6008/"); + + // Connection opened + socket.addEventListener( + 'open', + function(event) { document.getElementById('file').disabled = false; }); + + // Connection closed + socket.addEventListener('close', function(event) { + document.getElementById('file').disabled = true; + initWebSocket(); + }); + + // Listen for messages + socket.addEventListener('message', function(event) { + document.getElementById('results').innerHTML = event.data; + console.log('Received message: ', event.data); + }); +} + +function onFileChange() { + var files = document.getElementById("file").files; + + if (files.length == 0) { + console.log('No file selected'); + return; + } + + console.log('files: ' + files); + + const file = files[0]; + console.log(file); + console.log('file.name ' + file.name); + console.log('file.type ' + file.type); + console.log('file.size ' + file.size); + + let reader = new FileReader(); + reader.onload = function() { + let view = new Uint8Array(reader.result); + console.log('bytes: ' + view.byteLength); + // we assume the input file is a wav file. + // TODO: add some checks here. + let body = view.subarray(44); + socket.send(body); + socket.send(JSON.stringify({'eof' : 1})); + }; + + reader.readAsArrayBuffer(file); +} diff --git a/sherpa/csrc/CMakeLists.txt b/sherpa/csrc/CMakeLists.txt index 81759af9e..738bc26d7 100644 --- a/sherpa/csrc/CMakeLists.txt +++ b/sherpa/csrc/CMakeLists.txt @@ -1,6 +1,7 @@ # Please sort the filenames alphabetically set(sherpa_srcs rnnt_beam_search.cc + rnnt_emformer_model.cc rnnt_model.cc ) diff --git a/sherpa/csrc/rnnt_beam_search.cc b/sherpa/csrc/rnnt_beam_search.cc index 720308484..ac06f9d94 100644 --- a/sherpa/csrc/rnnt_beam_search.cc +++ b/sherpa/csrc/rnnt_beam_search.cc @@ -16,6 +16,11 @@ * limitations under the License. 
*/ +#include "sherpa/csrc/rnnt_beam_search.h" + +#include + +#include "sherpa/csrc/rnnt_emformer_model.h" #include "sherpa/csrc/rnnt_model.h" #include "torch/all.h" @@ -42,8 +47,8 @@ static void BuildDecoderInput(const std::vector> &hyps, } std::vector> GreedySearch( - RnntModel &model, torch::Tensor encoder_out, - torch::Tensor encoder_out_length) { + RnntModel &model, // NOLINT + torch::Tensor encoder_out, torch::Tensor encoder_out_length) { TORCH_CHECK(encoder_out.dim() == 3, "encoder_out.dim() is ", encoder_out.dim(), "Expected is 3"); TORCH_CHECK(encoder_out.scalar_type() == torch::kFloat, @@ -141,4 +146,50 @@ std::vector> GreedySearch( return ans; } +torch::Tensor StreamingGreedySearch(RnntEmformerModel &model, // NOLINT + torch::Tensor encoder_out, + torch::Tensor decoder_out, + std::vector> *hyps) { + TORCH_CHECK(encoder_out.dim() == 3, encoder_out.dim(), " vs ", 3); + TORCH_CHECK(decoder_out.dim() == 2, decoder_out.dim(), " vs ", 2); + + TORCH_CHECK(encoder_out.size(0) == decoder_out.size(0), encoder_out.size(0), + " vs ", decoder_out.size(0)); + + auto device = model.Device(); + int32_t blank_id = model.BlankId(); + int32_t unk_id = model.UnkId(); + int32_t context_size = model.ContextSize(); + + int32_t N = encoder_out.size(0); + int32_t T = encoder_out.size(1); + + auto decoder_input = + torch::full({N, context_size}, blank_id, + torch::dtype(torch::kLong) + .memory_format(torch::MemoryFormat::Contiguous)); + + for (int32_t t = 0; t != T; ++t) { + auto cur_encoder_out = encoder_out.index({torch::indexing::Slice(), t}); + + auto logits = model.ForwardJoiner(cur_encoder_out, decoder_out); + auto max_indices = logits.argmax(/*dim*/ -1).cpu(); + auto max_indices_accessor = max_indices.accessor(); + bool emitted = false; + for (int32_t n = 0; n != N; ++n) { + auto index = max_indices_accessor[n]; + if (index != blank_id && index != unk_id) { + emitted = true; + (*hyps)[n].push_back(index); + } + } + + if (emitted) { + BuildDecoderInput(*hyps, &decoder_input); + decoder_out = model.ForwardDecoder(decoder_input.to(device)).squeeze(1); + } + } + return decoder_out; +} + } // namespace sherpa diff --git a/sherpa/csrc/rnnt_beam_search.h b/sherpa/csrc/rnnt_beam_search.h index 72df87cf6..a2384d43a 100644 --- a/sherpa/csrc/rnnt_beam_search.h +++ b/sherpa/csrc/rnnt_beam_search.h @@ -20,6 +20,7 @@ #include +#include "sherpa/csrc/rnnt_emformer_model.h" #include "sherpa/csrc/rnnt_model.h" namespace sherpa { @@ -30,7 +31,7 @@ namespace sherpa { * * @param encoder_out Output from the encoder network. Its shape is * (batch_size, T, encoder_out_dim) and its dtype is - * torch::kFloat. + * torch::kFloat. It should be on the same device as `model`. * * @param encoder_out_lens A 1-D tensor containing the valid frames before * padding in `encoder_out`. Its dtype is torch.kLong @@ -42,8 +43,24 @@ namespace sherpa { * decoding results for the corresponding input in encoder_out. */ std::vector> GreedySearch( - RnntModel &model, torch::Tensor encoder_out, - torch::Tensor encoder_out_length); + RnntModel &model, // NOLINT + torch::Tensor encoder_out, torch::Tensor encoder_out_length); + +/** Greedy search for streaming recognition. + * + * @param model The stateless RNN-T Emformer model. + * @param encoder_out A 3-D tensor of shape (N, T, C). It should be on the same + * device as `model`. + * @param decoder_out A 2-D tensor of shape (N, C). It should be on the same + * device as `model`. + * @param hyps The decoded tokens. Note: It is modified in-place. 
+ * + * @return Return the decoder output for the next chunk. + */ +torch::Tensor StreamingGreedySearch(RnntEmformerModel &model, // NOLINT + torch::Tensor encoder_out, + torch::Tensor decoder_out, + std::vector> *hyps); } // namespace sherpa diff --git a/sherpa/csrc/rnnt_emformer_model.cc b/sherpa/csrc/rnnt_emformer_model.cc new file mode 100644 index 000000000..867207445 --- /dev/null +++ b/sherpa/csrc/rnnt_emformer_model.cc @@ -0,0 +1,103 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "sherpa/csrc/rnnt_emformer_model.h" + +#include + +namespace sherpa { + +RnntEmformerModel::RnntEmformerModel(const std::string &filename, + torch::Device device /*=torch::kCPU*/, + bool optimize_for_inference /*=false*/) + : device_(device) { + model_ = torch::jit::load(filename, device); + model_.eval(); + if (optimize_for_inference) { + model_ = torch::jit::optimize_for_inference(model_); + } + + encoder_ = model_.attr("encoder").toModule(); + decoder_ = model_.attr("decoder").toModule(); + joiner_ = model_.attr("joiner").toModule(); + + blank_id_ = decoder_.attr("blank_id").toInt(); + + unk_id_ = blank_id_; + if (decoder_.hasattr("unk_id")) { + unk_id_ = decoder_.attr("unk_id").toInt(); + } + + context_size_ = decoder_.attr("context_size").toInt(); + segment_length_ = encoder_.attr("segment_length").toInt(); + right_context_length_ = encoder_.attr("right_context_length").toInt(); +} + +std::pair +RnntEmformerModel::StreamingForwardEncoder( + const torch::Tensor &features, const torch::Tensor &features_length, + torch::optional states /*= torch::nullopt*/) { + // It contains [torch.Tensor, torch.Tensor, List[List[torch.Tensor]] + // which are [encoder_out, encoder_out_len, states] + // + // We skip the second entry `encoder_out_len` since we assume the + // feature input are of fixed chunk size and there are no paddings. + // We can figure out `encoder_out_len` from `encoder_out`. 
+ torch::IValue ivalue = encoder_.run_method("streaming_forward", features, + features_length, states); + auto tuple_ptr = ivalue.toTuple(); + torch::Tensor encoder_out = tuple_ptr->elements()[0].toTensor(); + + torch::List list = tuple_ptr->elements()[2].toList(); + int32_t num_layers = list.size(); + + std::vector> next_states; + next_states.reserve(num_layers); + + for (int32_t i = 0; i != num_layers; ++i) { + next_states.emplace_back( + c10::impl::toTypedList(list.get(i).toList()).vec()); + } + + return {encoder_out, next_states}; +} + +RnntEmformerModel::State RnntEmformerModel::GetEncoderInitStates() { + torch::IValue ivalue = encoder_.run_method("get_init_state", device_); + torch::List list = ivalue.toList(); + int32_t num_layers = list.size(); + State states; + states.reserve(num_layers); + for (int32_t i = 0; i != num_layers; ++i) { + states.emplace_back( + c10::impl::toTypedList(list.get(i).toList()).vec()); + } + return states; +} + +torch::Tensor RnntEmformerModel::ForwardDecoder( + const torch::Tensor &decoder_input) { + return decoder_.run_method("forward", decoder_input, /*need_pad*/ false) + .toTensor(); +} + +torch::Tensor RnntEmformerModel::ForwardJoiner( + const torch::Tensor &encoder_out, const torch::Tensor &decoder_out) { + return joiner_.run_method("forward", encoder_out, decoder_out).toTensor(); +} + +} // namespace sherpa diff --git a/sherpa/csrc/rnnt_emformer_model.h b/sherpa/csrc/rnnt_emformer_model.h new file mode 100644 index 000000000..b5dc11ce9 --- /dev/null +++ b/sherpa/csrc/rnnt_emformer_model.h @@ -0,0 +1,95 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef SHERPA_CSRC_RNNT_EMFORMER_MODEL_H_ +#define SHERPA_CSRC_RNNT_EMFORMER_MODEL_H_ + +#include +#include +#include + +#include "torch/script.h" + +namespace sherpa { + +/** It wraps a torch script model, which is from + * pruned_stateless_emformer_rnnt2/model.py within icefall. + */ +class RnntEmformerModel { + public: + /** + * @param filename Path name of the torch script model. + * @param device The model will be moved to this device + * @param optimize_for_inference true to invoke + * torch::jit::optimize_for_inference(). + */ + explicit RnntEmformerModel(const std::string &filename, + torch::Device device = torch::kCPU, + bool optimize_for_inference = false); + + ~RnntEmformerModel() = default; + + using State = std::vector>; + + std::pair StreamingForwardEncoder( + const torch::Tensor &features, const torch::Tensor &features_length, + torch::optional states = torch::nullopt); + + State GetEncoderInitStates(); + + /** Run the decoder network. + * + * @param decoder_input A 2-D tensor of shape (N, U). + * @return Return a tensor of shape (N, U, decoder_dim) + */ + torch::Tensor ForwardDecoder(const torch::Tensor &decoder_input); + + /** Run the joiner network. + * + * @param encoder_out A 2-D tensor of shape (N, C). 
+ * @param decoder_out A 2-D tensor of shape (N, C). + * @return Return a tensor of shape (N, vocab_size) + */ + torch::Tensor ForwardJoiner(const torch::Tensor &encoder_out, + const torch::Tensor &decoder_out); + + torch::Device Device() const { return device_; } + int32_t BlankId() const { return blank_id_; } + int32_t UnkId() const { return unk_id_; } + int32_t ContextSize() const { return context_size_; } + int32_t SegmentLength() const { return segment_length_; } + int32_t RightContextLength() const { return right_context_length_; } + + private: + torch::jit::Module model_; + + // The following modules are just aliases to modules in model_ + torch::jit::Module encoder_; + torch::jit::Module decoder_; + torch::jit::Module joiner_; + + torch::Device device_; + int32_t blank_id_; + int32_t unk_id_; + int32_t context_size_; + int32_t segment_length_; + int32_t right_context_length_; +}; + +} // namespace sherpa + +#endif // SHERPA_CSRC_RNNT_EMFORMER_MODEL_H_ diff --git a/sherpa/csrc/rnnt_model.h b/sherpa/csrc/rnnt_model.h index 024b82c41..6dcf84f63 100644 --- a/sherpa/csrc/rnnt_model.h +++ b/sherpa/csrc/rnnt_model.h @@ -18,7 +18,7 @@ #ifndef SHERPA_CSRC_RNNT_MODEL_H_ #define SHERPA_CSRC_RNNT_MODEL_H_ -#include +#include #include #include "torch/script.h" @@ -53,7 +53,7 @@ class RnntModel { * @param features A 3-D tensor of shape (N, T, C). * @param features_length A 1-D tensor of shape (N,) containing the number of * valid frames in `features`. - * @return Return a tuple containing two tensors: + * @return Return a pair containing two tensors: * - encoder_out, a 3-D tensor of shape (N, T, C) * - encoder_out_length, a 1-D tensor of shape (N,) containing the * number of valid frames in `encoder_out`. @@ -91,26 +91,6 @@ class RnntModel { */ torch::Tensor ForwardDecoderProj(const torch::Tensor &decoder_out); - /** TODO(fangjun): Implement it - * - * Run the encoder network in a streaming fashion. - * - * @param features A 3-D tensor of shape (N, T, C). - * @param features_length A 1-D tensor of shape (N,) containing the number of - * valid frames in `features`. - * @param prev_state It contains the previous state from the encoder network. - * - * @return Return a tuple containing 3 entries: - * - encoder_out, a 3-D tensor of shape (N, T, C) - * - encoder_out_length, a 1-D tensor of shape (N,) containing the - * number of valid frames in encoder_out - * - next_state, the state for the encoder network. - */ - std::tuple - StreamingForwardEncoder(const torch::Tensor &features, - const torch::Tensor &feature_lengths, - torch::IValue prev_state); - private: torch::jit::Module model_; diff --git a/sherpa/python/csrc/CMakeLists.txt b/sherpa/python/csrc/CMakeLists.txt index 282b1b771..b71446edc 100644 --- a/sherpa/python/csrc/CMakeLists.txt +++ b/sherpa/python/csrc/CMakeLists.txt @@ -3,6 +3,7 @@ add_definitions(-DTORCH_API_INCLUDE_EXTENSION_H) # Please sort files alphabetically pybind11_add_module(_sherpa rnnt_beam_search.cc + rnnt_emformer_model.cc rnnt_model.cc sherpa.cc ) diff --git a/sherpa/python/csrc/rnnt_beam_search.cc b/sherpa/python/csrc/rnnt_beam_search.cc index c61819a19..c689d0c90 100644 --- a/sherpa/python/csrc/rnnt_beam_search.cc +++ b/sherpa/python/csrc/rnnt_beam_search.cc @@ -15,17 +15,32 @@ * See the License for the specific language governing permissions and * limitations under the License. 
*/ -#include "sherpa/csrc/rnnt_beam_search.h" - #include "sherpa/python/csrc/rnnt_beam_search.h" + +#include +#include + +#include "sherpa/csrc/rnnt_beam_search.h" #include "torch/torch.h" namespace sherpa { -void PybindRnntBeamSearch(py::module &m) { +void PybindRnntBeamSearch(py::module &m) { // NOLINT m.def("greedy_search", &GreedySearch, py::arg("model"), py::arg("encoder_out"), py::arg("encoder_out_length"), py::call_guard()); + + m.def( + "streaming_greedy_search", + [](RnntEmformerModel &model, torch::Tensor encoder_out, + torch::Tensor decoder_out, std::vector> &hyps) + -> std::pair>> { + decoder_out = + StreamingGreedySearch(model, encoder_out, decoder_out, &hyps); + return {decoder_out, hyps}; + }, + py::arg("model"), py::arg("encoder_out"), py::arg("decoder_out"), + py::arg("hyps"), py::call_guard()); } } // namespace sherpa diff --git a/sherpa/python/csrc/rnnt_beam_search.h b/sherpa/python/csrc/rnnt_beam_search.h index a35367bbf..fc24f8da3 100644 --- a/sherpa/python/csrc/rnnt_beam_search.h +++ b/sherpa/python/csrc/rnnt_beam_search.h @@ -22,7 +22,7 @@ namespace sherpa { -void PybindRnntBeamSearch(py::module &m); +void PybindRnntBeamSearch(py::module &m); // NOLINT } // namespace sherpa diff --git a/sherpa/python/csrc/rnnt_emformer_model.cc b/sherpa/python/csrc/rnnt_emformer_model.cc new file mode 100644 index 000000000..ad98b8776 --- /dev/null +++ b/sherpa/python/csrc/rnnt_emformer_model.cc @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "sherpa/python/csrc/rnnt_emformer_model.h" + +#include +#include + +#include "sherpa/csrc/rnnt_emformer_model.h" +#include "torch/torch.h" + +namespace sherpa { + +void PybindRnntEmformerModel(py::module &m) { // NOLINT + using PyClass = RnntEmformerModel; + py::class_(m, "RnntEmformerModel") + .def(py::init([](const std::string &filename, + py::object device = py::str("cpu"), + bool optimize_for_inference = + false) -> std::unique_ptr { + std::string device_str = + device.is_none() ? 
"cpu" : py::str(device); + return std::make_unique( + filename, torch::Device(device_str), optimize_for_inference); + }), + py::arg("filename"), py::arg("device") = py::str("cpu"), + py::arg("optimize_for_inference") = false) + .def("encoder_streaming_forward", &PyClass::StreamingForwardEncoder, + py::arg("features"), py::arg("features_length"), + py::arg("states") = py::none(), + py::call_guard()) + .def("decoder_forward", &PyClass::ForwardDecoder, + py::arg("decoder_input"), py::call_guard()) + .def("get_encoder_init_states", &PyClass::GetEncoderInitStates, + py::call_guard()) + .def_property_readonly("device", + [](const PyClass &self) -> py::object { + py::object ans = + py::module_::import("torch").attr("device"); + return ans(self.Device().str()); + }) + .def_property_readonly("blank_id", &PyClass::BlankId) + .def_property_readonly("context_size", &PyClass::ContextSize) + .def_property_readonly("segment_length", &PyClass::SegmentLength) + .def_property_readonly("right_context_length", + &PyClass::RightContextLength); +} + +} // namespace sherpa diff --git a/sherpa/python/csrc/rnnt_emformer_model.h b/sherpa/python/csrc/rnnt_emformer_model.h new file mode 100644 index 000000000..58813f359 --- /dev/null +++ b/sherpa/python/csrc/rnnt_emformer_model.h @@ -0,0 +1,29 @@ +/** + * Copyright (c) 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +#ifndef SHERPA_PYTHON_CSRC_RNNT_EMFORMER_MODEL_H_ +#define SHERPA_PYTHON_CSRC_RNNT_EMFORMER_MODEL_H_ + +#include "sherpa/python/csrc/sherpa.h" + +namespace sherpa { + +void PybindRnntEmformerModel(py::module &m); // NOLINT + +} // namespace sherpa + +#endif // SHERPA_PYTHON_CSRC_RNNT_EMFORMER_MODEL_H_ diff --git a/sherpa/python/csrc/rnnt_model.cc b/sherpa/python/csrc/rnnt_model.cc index 089d45cdb..06d3660a1 100644 --- a/sherpa/python/csrc/rnnt_model.cc +++ b/sherpa/python/csrc/rnnt_model.cc @@ -19,13 +19,14 @@ #include "sherpa/python/csrc/rnnt_model.h" #include +#include #include "sherpa/csrc/rnnt_model.h" #include "torch/torch.h" namespace sherpa { -void PybindRnntModel(py::module &m) { +void PybindRnntModel(py::module &m) { // NOLINT using PyClass = RnntModel; py::class_(m, "RnntModel") .def(py::init([](const std::string &filename, diff --git a/sherpa/python/csrc/rnnt_model.h b/sherpa/python/csrc/rnnt_model.h index 5a8ce700a..cf2731936 100644 --- a/sherpa/python/csrc/rnnt_model.h +++ b/sherpa/python/csrc/rnnt_model.h @@ -22,7 +22,7 @@ namespace sherpa { -void PybindRnntModel(py::module &m); +void PybindRnntModel(py::module &m); // NOLINT } // namespace sherpa diff --git a/sherpa/python/csrc/sherpa.cc b/sherpa/python/csrc/sherpa.cc index 5e5c90ef6..5de5a0dea 100644 --- a/sherpa/python/csrc/sherpa.cc +++ b/sherpa/python/csrc/sherpa.cc @@ -19,6 +19,7 @@ #include "sherpa/python/csrc/sherpa.h" #include "sherpa/python/csrc/rnnt_beam_search.h" +#include "sherpa/python/csrc/rnnt_emformer_model.h" #include "sherpa/python/csrc/rnnt_model.h" namespace sherpa { @@ -27,6 +28,7 @@ PYBIND11_MODULE(_sherpa, m) { m.doc() = "pybind11 binding of sherpa"; PybindRnntModel(m); + PybindRnntEmformerModel(m); PybindRnntBeamSearch(m); } diff --git a/sherpa/python/sherpa/__init__.py b/sherpa/python/sherpa/__init__.py index b15022aeb..10f50e2ea 100644 --- a/sherpa/python/sherpa/__init__.py +++ b/sherpa/python/sherpa/__init__.py @@ -1,3 +1,2 @@ import torch - -from _sherpa import RnntModel, greedy_search +from _sherpa import RnntEmformerModel, RnntModel, greedy_search, streaming_greedy_search
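Taken together, these changes expose a complete streaming decode path from Python: RnntEmformerModel wraps the torchscript Emformer transducer, get_encoder_init_states() and encoder_streaming_forward() carry the per-stream state, and streaming_greedy_search() advances the hypotheses one chunk at a time. The sketch below shows one way to drive that loop. It is illustrative only: the checkpoint path is hypothetical, the random tensor stands in for real fbank features (the 80-dim shape is an assumption), and the conventions of feeding segment_length + right_context_length frames per call and of priming hyps with context_size blanks mirror the C++ StreamingGreedySearch above rather than a documented contract.

import torch
import sherpa

# Hypothetical path to a torchscript export of pruned_stateless_emformer_rnnt2.
model = sherpa.RnntEmformerModel("cpu_jit.pt", device="cpu")

segment_length = model.segment_length
right_context = model.right_context_length
context_size = model.context_size
blank_id = model.blank_id

# Assumption: each streaming call sees one segment plus its look-ahead frames.
chunk_size = segment_length + right_context

# Stand-in for real log-mel fbank features of a single stream: (N=1, T, feature_dim).
features = torch.randn(1, 1000, 80)

states = model.get_encoder_init_states()

# Prime the decoder history with context_size blanks, as the C++ greedy search does.
hyps = [[blank_id] * context_size]
decoder_input = torch.tensor(hyps, dtype=torch.int64)
decoder_out = model.decoder_forward(decoder_input).squeeze(1)  # (N, 1, C) -> (N, C)

T = features.size(1)
# Only full chunks are processed in this sketch; trailing frames are dropped.
for start in range(0, T - chunk_size + 1, segment_length):
    chunk = features[:, start:start + chunk_size, :]
    chunk_lens = torch.tensor([chunk.size(1)], dtype=torch.int64)

    encoder_out, states = model.encoder_streaming_forward(chunk, chunk_lens, states)

    # Returns the updated decoder output for the next chunk and the grown hypotheses.
    decoder_out, hyps = sherpa.streaming_greedy_search(
        model, encoder_out, decoder_out, hyps
    )

tokens = hyps[0][context_size:]  # drop the priming blanks
print("decoded token IDs:", tokens)

In a real deployment the features would be computed from the audio that the web pages above stream over the websocket, and the growing hypotheses would be mapped back to text (e.g. with the BPE model the system was trained with) and pushed into the page's results element, as upload.js expects.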