Support pre-trained CTC models from NeMo #332
Conversation
@titu1994 You may find this pull request interesting and helpful.
Caution: see the code below.
This is fantastic! Thank you very much for this integration, and let me know how I can help (I'm adding docs, as discussed in the other thread). We could potentially add a link to your example in our decoding docs section if it supports CTC models with both char and subword tokenizers.
Will update the doc in a separate PR.
There are several issues with the torchscript models from NeMo.
It turns out the comment at `forward_for_export` is misleading:

```python
def forward_for_export(self, input, length=None, cache_last_channel=None, cache_last_time=None):
    """
    This forward is used when we need to export the model to ONNX format.
    Inputs cache_last_channel and cache_last_time are needed to be passed for exporting streaming models.
    When they are passed, it just passes the inputs through the encoder part and currently the ONNX conversion does not fully work for this case.

    Args:
        input: Tensor that represents a batch of raw audio signals,
            of shape [B, T]. T here represents timesteps.
        length: Vector of length B, that contains the individual lengths of the audio sequences.
    """
```

The comment says the shape of `input` is `[B, T]`, i.e., a batch of raw audio signals, yet the exported model actually expects features from the preprocessor. It took me a really long time to figure that out.
NeMo has neural types in each of its models; that's what you should use to determine shapes. You can do this by calling `model.input_types` and `model.output_types`, which both return a dictionary of neural types and usually also note the order of the tensor axes inside each arg, along with the arg name if you want to pass args as key:value pairs.
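As a hedged illustration of this (assuming `nemo_toolkit[asr]` is installed; the exact type names printed will vary from model to model), inspecting these dictionaries looks like:

```python
# A minimal sketch: query a NeMo model's neural types to learn tensor shapes.
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecCTCModel.from_pretrained('stt_zh_citrinet_512')

# Both properties map argument names to NeuralType objects; each NeuralType
# records the axis order, e.g. ('B', 'T') for input_signal.
for name, neural_type in model.input_types.items():
    print('input: ', name, neural_type)
for name, neural_type in model.output_types.items():
    print('output:', name, neural_type)
```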
To give you an example, the following code

```python
import torch
import torchaudio
import nemo.collections.asr as nemo_asr

citrinet_zh = nemo_asr.models.EncDecCTCModel.from_pretrained('stt_zh_citrinet_512')
citrinet_zh.export("model.pt")

samples, sample_rate = torchaudio.load("./BAC009S0764W0121.wav")
print('samples', samples.shape, sample_rate)

features, features_len = citrinet_zh.preprocessor(
    input_signal=samples, length=torch.tensor([samples.shape[1]])
)
print('features', features.shape, features_len)

model = torch.jit.load("model.pt")
log_probs = model(features, features_len)
print(log_probs.shape)
```

has the following output:
You can see that the model returns only a single output: the model has merged the encoder and decoder into a single module. It would be nice if the model could also return the lengths of the output.
Hmm, I see. I will ask our team members if it's possible to change this, as returning the output lengths needs to be optional (Riva usually does not want it). It should be doable; RNNT already supports this, but we need to check how to implement it without damaging preexisting models and exposed paths.
Does Riva support batch CTC decoding? We need the length information for batch CTC decoding in sherpa.
Thanks!
Btw, to note: once a model calls `.export()`, it is considered a corrupted model. I would suggest not trusting the output of such a model; instead, delete it and load the jit model, and restore another copy of the NeMo model if you need the preprocessor. Another thing: if you have Torchaudio installed, you can export the preprocessor too; see https://github.com/NVIDIA/NeMo/pull/5512. Dunno why, but I forgot to add it to the docs.
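A minimal sketch of that suggested workflow (reusing the model and file names from the earlier example; an illustration, not official NeMo guidance):

```python
# Export once, then discard that instance; load the jit model for inference
# and a fresh NeMo copy if you still need the preprocessor.
import torch
import nemo.collections.asr as nemo_asr

exporter = nemo_asr.models.EncDecCTCModel.from_pretrained('stt_zh_citrinet_512')
exporter.export("model.pt")
del exporter  # a model that has called .export() should not be reused

jit_model = torch.jit.load("model.pt")
fresh = nemo_asr.models.EncDecCTCModel.from_pretrained('stt_zh_citrinet_512')
# use fresh.preprocessor to compute the features fed to jit_model
```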
Yep.
Their preprocessor internally keeps track of it for CTC, so it somehow works, but I'm not sure of the internals.
Does Riva also use the torchscript model? If so, how does Riva know the lengths of the output?
FYI: the documentation of the pre-trained models from NeMo is available in the sherpa docs (a screenshot was attached here). We can convert more models if needed.
Riva supports both ONNX and TS output; as to how they support it without an explicit export, no idea. It's easy enough to estimate the sequence length by dividing by the model stride (the length from the preprocessor // 4 for Conformer or // 8 for Citrinet), which should give you a nearly correct sequence length.
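As a hedged sketch of that estimate (the subsampling factors are the ones stated above; a given model's internal rounding at sequence boundaries may differ slightly):

```python
# Approximate encoder output lengths by integer-dividing feature lengths
# by the model's subsampling factor (4 for Conformer, 8 for Citrinet).
import torch

def estimate_output_lengths(feature_lengths: torch.Tensor,
                            subsampling_factor: int) -> torch.Tensor:
    return feature_lengths // subsampling_factor

feature_lengths = torch.tensor([1000, 764])
print(estimate_output_lengths(feature_lengths, 8))  # tensor([125, 95])
```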
That's great! Can you try one of the Conformer CTC models? Those are currently the state-of-the-art models in NeMo, trained on much more data than Citrinet.
The C++ code is fairly generic, and all it takes is a torchscript model. Is it possible to read the subsampling factor from the torchscript model? If not, could you add some attributes to the model before exporting so that we can read them in the C++ code? In icefall, we add attributes to the model, such as the vocab size and subsampling factor, so that we can read them in C++ within sherpa.
Thanks. I will try the conformer model.
Hi @csukuangfj, everything that I have tried till now uses Conformer CTC models only. FYI. Thanks.
@titu1994 I am having the following error while exporting a Conformer CTC model to torchscript. The code:

```python
import nemo.collections.asr as nemo_asr

m = nemo_asr.models.EncDecCTCModelBPE.from_pretrained('stt_en_conformer_ctc_small')
m.export("model.pt")
```

The error for the above code:
Do you have any suggestions about how to fix it? I am using the following code to install NeMo in a Google Colab notebook:

```python
## Install dependencies
!pip install wget
!apt-get install sox libsndfile1 ffmpeg
!pip install text-unidecode

## Install NeMo
BRANCH = 'main'
!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]

## Install TorchAudio
!pip install torchaudio>=0.10.0 -f https://download.pytorch.org/whl/torch_stable.html

## Grab the config we'll use in this example
!mkdir configs
```
I found a solution from pytorch/pytorch#81085 (comment).
I have updated the documentation to include Conformer CTC models from NeMo. I also added a section describing how to export CTC models from NeMo to sherpa; please see the updated docs. By the way, could you add Conformer CTC models for more languages, e.g., Chinese?
Sure, we can look into this. Is there an example or documentation of how to add attributes to a torchscript export? I have notified the team about the issue with the torchscript export of Conformer; thanks for finding it out! We have an export test for Conformer ONNX, since that's usually more efficient, but we want to support both ONNX and TS in NeMo for compatibility.
Yes, absolutely! We have a ton of languages with Conformer support. Here is a non-exhaustive list of all models on NGC for the languages we currently support: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/scores.html. Hugging Face just has the most popular models, but we can add more to HF if there is a request for it.
If a model has some scalar attributes, then after exporting to torchscript, the scalar attributes will be kept in the resulting exported model. For instance, in icefall, the decoder model has the following attributes:

```python
self.context_size = context_size
self.vocab_size = vocab_size
```

And we can access the attributes of the exported decoder model in sherpa using the following code:
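As a hedged stand-in for that snippet, here is a minimal Python sketch (a hypothetical `Decoder` module, not the actual icefall code) showing that scalar attributes assigned in `__init__` survive `torch.jit.script` export and can be read back from the loaded module; in C++, `torch::jit::Module::attr` serves the same purpose:

```python
# Scalar attributes set in __init__ are serialized with the scripted module.
import torch

class Decoder(torch.nn.Module):
    def __init__(self, vocab_size: int, context_size: int):
        super().__init__()
        self.vocab_size = vocab_size
        self.context_size = context_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

m = torch.jit.script(Decoder(vocab_size=500, context_size=2))
m.save("decoder.pt")

loaded = torch.jit.load("decoder.pt")
print(loaded.vocab_size, loaded.context_size)  # 500 2
```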
Unfortunately, the list does not have Conformer CTC models for Chinese.
Thanks!
Oh, it seems I was mistaken: we have a Conformer Transducer Large trained on Mandarin, but not CTC. We found Citrinet to do better in CER, so we didn't release the checkpoint. I suppose we could look into it in the future.
I see. Thanks!
Fixes #303
Fixes #238
TODOs
Usage example
We have converted Citrinet-512 from NeMo:
https://catalog.ngc.nvidia.com/orgs/nvidia/teams/nemo/models/stt_en_citrinet_512
The model is saved at
https://huggingface.co/csukuangfj/sherpa-nemo-ctc-en-citrinet-512
In the following, we describe how to use sherpa to decode sound files with pre-trained CTC models from NeMo.
Build sherpa
Download the pre-trained model
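As a hedged sketch for this step (the official docs may use different commands, e.g. a git-lfs clone), the files in the Hugging Face repo given above can be fetched with `huggingface_hub`:

```python
# Download model.pt, tokens.txt, and test_wavs/ from the Hugging Face repo.
from huggingface_hub import snapshot_download

path = snapshot_download("csukuangfj/sherpa-nemo-ctc-en-citrinet-512")
print(path)  # local directory containing the downloaded files
```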
Use the pre-trained model
```bash
cd /path/to/sherpa

./build/bin/sherpa-offline \
  --nn-model=./sherpa-nemo-ctc-en-citrinet-512/model.pt \
  --tokens=./sherpa-nemo-ctc-en-citrinet-512/tokens.txt \
  --use-gpu=false \
  --modified=false \
  --nemo-normalize=per_feature \
  ./sherpa-nemo-ctc-en-citrinet-512/test_wavs/0.wav
```
You should see the following output: