Add streaming Emformer stateless RNN-T. #390

Merged (4 commits) on Jun 1, 2022

Conversation

csukuangfj
Collaborator

This PR uses the Emformer model from torchaudio, which requires torchaudio >= 0.11.0.

Training command

./pruned_stateless_emformer_rnnt2/train.py \
  --world-size 8 \
  --num-epochs 40 \
  --start-epoch 1 \
  --exp-dir pruned_stateless_emformer_rnnt2/exp-full \
  --full-libri 1 \
  --use-fp16 0 \
  --max-duration 200 \
  --prune-range 5 \
  --lm-scale 0.25 \
  --master-port 12358 \
  --num-encoder-layers 18 \
  --left-context-length 128 \
  --segment-length 8 \
  --right-context-length 4
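As a rough sanity check on the streaming settings above, the chunk size and lookahead can be derived from `--segment-length` and `--right-context-length`. This is a hypothetical back-of-the-envelope sketch: it assumes the usual 10 ms fbank frame shift and the subsampling factor of 4 shown in the config dump later in this thread, and that both lengths count encoder frames after subsampling.

```python
# Hypothetical latency arithmetic for the streaming Emformer settings above.
# Assumptions (not stated in the command itself): 10 ms feature frame shift,
# subsampling factor 4, and segment/right-context lengths counted in
# post-subsampling encoder frames.

FRAME_SHIFT_MS = 10
SUBSAMPLING = 4
ENCODER_FRAME_MS = FRAME_SHIFT_MS * SUBSAMPLING  # one encoder frame covers 40 ms

segment_length = 8        # --segment-length
right_context_length = 4  # --right-context-length

chunk_ms = segment_length * ENCODER_FRAME_MS            # audio processed per chunk
lookahead_ms = right_context_length * ENCODER_FRAME_MS  # future audio the model peeks at

print(chunk_ms, lookahead_ms)
```

Under these assumptions, the model consumes audio in 320 ms chunks and looks 160 ms into the future, which bounds the algorithmic latency.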

Decoding command

for m in greedy_search fast_beam_search modified_beam_search; do
  for epoch in 39; do
    for avg in 6; do
      ./pruned_stateless_emformer_rnnt2/decode.py \
        --epoch $epoch \
        --avg $avg \
        --use-averaged-model 1 \
        --exp-dir pruned_stateless_emformer_rnnt2/exp-full \
        --max-duration 50 \
        --decoding-method $m \
        --num-encoder-layers 18 \
        --left-context-length 128 \
        --segment-length 8 \
        --right-context-length 4
    done
  done
done
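The "stateless" part of the model name refers to the prediction network: during `greedy_search` it conditions only on the last `context_size = 2` emitted tokens (visible in the config dump later in this thread) rather than on a recurrent state. The following is a toy, self-contained sketch of that decoding loop; `toy_joiner` is a hypothetical stand-in for the real joiner network, not the recipe's implementation.

```python
# Toy illustration of RNN-T greedy search with a *stateless* decoder:
# the prediction network sees only the last CONTEXT_SIZE emitted tokens
# (a fixed-size sliding window), never a full recurrent state.

CONTEXT_SIZE = 2
BLANK = 0

def toy_joiner(encoder_frame, context):
    """Hypothetical joiner: per-token scores for one encoder frame.

    Pretends each frame carries a target label and scores it high only
    when it differs from the most recent context token, so repeated
    frames collapse to a single emission followed by blanks.
    """
    label = encoder_frame
    scores = {BLANK: 1.0}
    if context[-1] != label:
        scores[label] = 2.0
    return scores

def greedy_search(encoder_frames, max_sym_per_frame=1):
    context = [BLANK] * CONTEXT_SIZE  # fixed-size history, no RNN state
    hyp = []
    for frame in encoder_frames:
        for _ in range(max_sym_per_frame):
            scores = toy_joiner(frame, context)
            best = max(scores, key=scores.get)
            if best == BLANK:
                break  # advance to the next frame
            hyp.append(best)
            context = context[1:] + [best]  # slide the 2-token window
    return hyp

print(greedy_search([5, 5, 7, 7, 7, 3]))  # repeated frames emit one token each
```

Because the decoder state is just the last two tokens, the prediction network can be a simple embedding lookup, which is what makes streaming deployment cheap.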
| method | test-clean | test-other | comment |
|---|---|---|---|
| greedy search | 4.28 | 11.42 | epoch 39, avg 6 |
| modified beam search | 4.22 | 11.16 | epoch 39, avg 6 |
| fast beam search | 4.29 | 11.26 | epoch 39, avg 6 |

The baseline is from
https://github.com/pytorch/audio/blob/main/examples/asr/emformer_rnnt/README.md

  • test-clean: 4.56
  • test-other: 10.66

Note that the baseline was trained for 120 epochs with 32 GPUs.

Also, the baseline uses vocab size 4098.
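To put the numbers above side by side, here is a quick comparison of this PR's modified beam search results against the torchaudio baseline (both taken verbatim from above); a positive relative change means this PR is better on that split:

```python
# Relative WER change of this PR (modified beam search) vs. the
# torchaudio Emformer RNN-T baseline, using the numbers quoted above.
ours = {"test-clean": 4.22, "test-other": 11.16}      # this PR
baseline = {"test-clean": 4.56, "test-other": 10.66}  # torchaudio baseline

rel_change = {
    split: round((baseline[split] - ours[split]) / baseline[split] * 100, 1)
    for split in ours
}
print(rel_change)  # positive = this PR is better
```

This PR comes out ahead on test-clean and behind on test-other, despite the baseline's much longer training schedule and larger vocabulary.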

Will switch to #389

The pretrained model can be used in k2-fsa/sherpa#6 for streaming ASR.

I am uploading the training logs, decoding results, decoding logs, and pretrained model to Hugging Face.

@csukuangfj
Collaborator Author

I have uploaded the pretrained model, training logs, decoding logs, and decoding results to

https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01

You can use the pretrained model in https://github.com/k2-fsa/sherpa, which is an ASR server in Python supporting both streaming and non-streaming ASR.

The following is a YouTube video demonstrating its use in sherpa.

https://www.youtube.com/watch?v=z7HgaZv5W0U

@csukuangfj csukuangfj merged commit fbfc98f into k2-fsa:master Jun 1, 2022
@csukuangfj csukuangfj deleted the streaming-emformer-2022-05-27 branch June 1, 2022 06:31
@Tomiinek

Tomiinek commented Oct 10, 2022

Hi @csukuangfj, I am trying to use the non-torchscripted checkpoint you released on Hugging Face (pretrained-epoch-39-avg-6-use-averaged-model-1.pt), but I get poor results when running decode.py with it: roughly 100% WER on both LibriSpeech test-clean and test-other with on-the-fly features. For comparison, a randomly initialized model gives me about 200% WER, and models I train from scratch reach reasonable WERs. Could you please make sure the checkpoint works?

@csukuangfj
Collaborator Author

What is your decoding command?

@csukuangfj
Collaborator Author

I just rechecked it with the following commands:

cd egs/librispeech/ASR/
mkdir t
cd t
ln -s /ceph-fj/fangjun/open-source-2/icefall-models//icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/exp/pretrained-epoch-39-avg-6-use-averaged-model-1.pt epoch-99.pt

ln -s /ceph-fj/fangjun/open-source-2/icefall-models//icefall-asr-librispeech-pruned-stateless-emformer-rnnt2-2022-06-01/data/lang_bpe_500/bpe.model ./

cd ../

./pruned_stateless_emformer_rnnt2/decode.py \
  --epoch 99 \
  --avg 1 \
  --use-averaged-model 0 \
  --exp-dir ./t/ \
  --bpe-model ./t/bpe.model \
  --max-duration 50 \
  --decoding-method greedy_search \
  --num-encoder-layers 18 \
  --left-context-length 128 \
  --segment-length 8 \
  --right-context-length 4

It gives me the following output:

/ceph-fj/fangjun/open-source-2/audio/torchaudio/_extension.py:83: UserWarning: torchaudio C++ extension is not available.
  warnings.warn("torchaudio C++ extension is not available.")
2022-10-10 22:20:21,859 INFO [decode.py:508] Decoding started
2022-10-10 22:20:21,859 INFO [decode.py:514] Device: cuda:0
2022-10-10 22:20:21,863 INFO [decode.py:524] {'best_train_loss': inf, 'best_valid_loss': inf, 'best_train_epoch': -1, 'best_valid_epoch': -1, 'batch_idx_train': 0, 'log_interval': 50, 'reset_interval': 200, 'valid_interval': 3000, 'feature_dim': 80, 'subsampling_factor': 4, 'vgg_frontend': False, 'embedding_dim': 512, 'warm_step': 80000, 'env_info': {'k2-version': '1.21', 'k2-build-type': 'Release', 'k2-with-cuda': True, 'k2-git-sha1': '4810e00d8738f1a21278b0156a42ff396a2d40ac', 'k2-git-date': 'Fri Oct 7 19:35:03 2022', 'lhotse-version': '1.9.0.dev+missing.version.file', 'torch-version': '1.10.0+cu102', 'torch-cuda-available': True, 'torch-cuda-version': '10.2', 'python-version': '3.8', 'icefall-git-branch': 'master', 'icefall-git-sha1': '099cd3a-clean', 'icefall-git-date': 'Tue Sep 20 22:52:49 2022', 'icefall-path': '/k2-dev/fangjun/open-source/icefall-master-2', 'k2-path': '/k2-dev/fangjun/open-source/k2-master/k2/python/k2/__init__.py', 'lhotse-path': '/k2-dev/fangjun/open-source/lhotse-master-2/lhotse/__init__.py', 'hostname': 'de-74279-k2-train-6-0701202559-8476c48f5f-xmr4s', 'IP address': '10.177.28.74'}, 'epoch': 99, 'iter': 0, 'avg': 1, 'use_averaged_model': False, 'exp_dir': PosixPath('t'), 'bpe_model': './t/bpe.model', 'decoding_method': 'greedy_search', 'beam_size': 4, 'beam': 4, 'max_contexts': 4, 'max_states': 8, 'context_size': 2, 'max_sym_per_frame': 1, 'attention_dim': 512, 'nhead': 8, 'dim_feedforward': 2048, 'num_encoder_layers': 18, 'left_context_length': 128, 'segment_length': 8, 'right_context_length': 4, 'memory_size': 0, 'full_libri': True, 'manifest_dir': PosixPath('data/fbank'), 'max_duration': 50, 'bucketing_sampler': True, 'num_buckets': 30, 'concatenate_cuts': False, 'duration_factor': 1.0, 'gap': 1.0, 'on_the_fly_feats': False, 'shuffle': True, 'drop_last': True, 'return_cuts': True, 'num_workers': 2, 'enable_spec_aug': True, 'spec_aug_time_warp_factor': 80, 'enable_musan': True, 'input_strategy': 'PrecomputedFeatures', 'res_dir': PosixPath('t/greedy_search'), 'suffix': 'epoch-99-avg-1-context-2-max-sym-per-frame-1', 'blank_id': 0, 'unk_id': 2, 'vocab_size': 500}

2022-10-10 22:20:21,863 INFO [decode.py:526] About to create model
2022-10-10 22:20:23,015 INFO [checkpoint.py:112] Loading checkpoint from t/epoch-99.pt
2022-10-10 22:20:29,846 INFO [decode.py:615] Number of model parameters: 65390556
2022-10-10 22:20:29,847 INFO [asr_datamodule.py:444] About to get test-clean cuts
2022-10-10 22:20:29,866 INFO [asr_datamodule.py:451] About to get test-other cuts
2022-10-10 22:20:31,493 INFO [decode.py:418] batch 0/?, cuts processed until now is 3
2022-10-10 22:21:02,966 INFO [decode.py:418] batch 50/?, cuts processed until now is 257
2022-10-10 22:21:33,340 INFO [decode.py:418] batch 100/?, cuts processed until now is 514
2022-10-10 22:22:03,113 INFO [decode.py:418] batch 150/?, cuts processed until now is 809
2022-10-10 22:22:32,767 INFO [decode.py:418] batch 200/?, cuts processed until now is 1099
2022-10-10 22:23:05,717 INFO [decode.py:418] batch 250/?, cuts processed until now is 1356
2022-10-10 22:23:31,976 INFO [decode.py:418] batch 300/?, cuts processed until now is 1710
2022-10-10 22:24:02,622 INFO [decode.py:418] batch 350/?, cuts processed until now is 1985
2022-10-10 22:24:32,952 INFO [decode.py:418] batch 400/?, cuts processed until now is 2246
2022-10-10 22:25:07,047 INFO [decode.py:418] batch 450/?, cuts processed until now is 2486
2022-10-10 22:25:35,278 INFO [decode.py:436] The transcripts are stored in t/greedy_search/recogs-test-clean-greedy_search-epoch-99-avg-1-context-2-max-sym-per-frame-1.txt
2022-10-10 22:25:35,433 INFO [utils.py:428] [test-clean-greedy_search] %WER 4.28% [2251 / 52576, 258 ins, 181 del, 1812 sub ]
2022-10-10 22:25:35,771 INFO [decode.py:449] Wrote detailed error stats to t/greedy_search/errs-test-clean-greedy_search-epoch-99-avg-1-context-2-max-sym-per-frame-1.txt
2022-10-10 22:25:35,772 INFO [decode.py:466]
For test-clean, WER of different settings are:
greedy_search   4.28    best for test-clean

@csukuangfj
Collaborator Author

Note: I am using the Emformer model from the following commit of the torchaudio repo:

commit 93024ace026e6e0a30449a932fa30cfd49258251 (HEAD)
Author: Caroline Chen <[email protected]>
Date:   Tue May 31 19:27:28 2022 -0700

    Move CTC beam search decoder to beta (#2410)

@Tomiinek

Thanks for your prompt reply, it was very helpful! I used a mismatched BPE model 🙄 Sorry for bothering you.

@csukuangfj
Collaborator Author

Thanks for your prompt reply, it was very helpful! I used a mismatched BPE model 🙄 Sorry for bothering you.

Never mind. Glad to hear it works for you.
