ONNX export and inference #8
Conversation
Hi @mush42,
@rmcpantoja Fixed in 2c21a0e
I guess it would have been useful to have a test suite; I will work on that. Also, could you retarget the PR to the dev branch, since that branch is dedicated to this use case?
I am sorry, I am a bit delayed with the review; I will do it by the end of this week. Apologies for the delay.
No problem, take your time. Best
Awesome! Thank you very much for this.
Glad to be of help. Let's take a swig of that tasty Matcha 🙂
What does this PR do?
Enables exporting Matcha models to ONNX, and running inference with the exported models.
For neural TTS models, my experience with two different TTS systems shows that onnxruntime is 2x-3x faster than torch, and significantly faster than TorchScript.
For on-device inference, ONNX has an excellent deployment story as demonstrated by Piper TTS.
This PR adds two new dependencies:
- onnx: required by torch's ONNX exporter
- onnxruntime: needed for running inference with exported models

Note: torch>=2.1.0 is needed only for exporting, since the scaled_product_attention operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually from pre-releases.

Fixes #6
Changes
Since ONNX export depends on torch tracing, and tracing cannot trace Python values, three stylistic changes were made to satisfy that condition. In essence: pass tensors as-is without converting them to integers using int(), and use alternatives to Python if statements, since they are not traceable.
Breaking changes
None
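The tracing-friendly style described under Changes can be sketched as follows. This is an illustrative example, not the actual Matcha-TTS code: it contrasts a function that bakes Python values into the trace with one that keeps everything as tensors via torch.where.

```python
# Hypothetical example of the two patterns: avoid int() on tensors, and
# replace a Python `if` with torch.where so the branch stays traceable.
import torch

def masked_scale_eager(x, length):
    # Not trace-safe: int() and the Python `if` freeze concrete values
    # into the trace, so the exported graph ignores other inputs.
    n = int(length)
    if n > x.shape[-1] // 2:
        return x * 2.0
    return x

def masked_scale_traceable(x, length):
    # Trace-safe: the condition remains a tensor, so torch.onnx.export
    # records the comparison and the select as graph operations.
    cond = length > (x.shape[-1] // 2)
    return torch.where(cond, x * 2.0, x)

x = torch.ones(1, 4)
print(masked_scale_traceable(x, torch.tensor(3)))  # scaled: length 3 > 2
print(masked_scale_traceable(x, torch.tensor(1)))  # unchanged: length 1 <= 2
```

Both functions agree in eager mode; only the second produces a graph whose behavior still depends on `length` after export.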
Before submitting
- Did you run all tests with the pytest command?
- Did you run code-style checks with the pre-commit run -a command?