
ONNX export and inference #8

Merged — 5 commits merged on Sep 29, 2023
Conversation

@mush42 (Contributor) commented Sep 24, 2023

What does this PR do?

Enables exporting Matcha models to ONNX, and running inference on the exported models.

For neural TTS models, my experience with two different TTS systems shows that onnxruntime is 2x-3x faster than torch, and significantly faster than TorchScript.

For on-device inference, ONNX has an excellent deployment story as demonstrated by Piper TTS.

This PR adds two new dependencies:

  • onnx: required by torch's ONNX exporter
  • onnxruntime: needed for running inference on exported models

Note: torch>=2.1.0 is needed only for export, since the scaled_dot_product_attention operator is not exportable in older versions. Until the final version is released, those who want to export their models must install torch>=2.1.0 manually from the pre-releases.

Fixes #6

Changes

Since ONNX export depends on torch tracing, and tracing cannot track Python values, three stylistic changes were made to satisfy that constraint. In essence: pass tensors as-is without converting them to integers via int(), and use traceable alternatives to Python if statements, since those are not traceable.
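A hedged sketch of the kind of rewrite described above (illustrative functions, not the actual Matcha diff): a Python `if` on a tensor value gets frozen into the trace as a constant branch, while `torch.where` keeps both branches in the graph.

```python
# Illustrative only: trace-unfriendly vs. trace-friendly control flow.
import torch


def gate_untraceable(x, threshold):
    # float()/if pull the tensor into Python, so torch.jit.trace bakes in
    # whichever branch was taken during tracing.
    if float(threshold) > 0.5:
        return x * 2
    return x


def gate_traceable(x, threshold):
    # torch.where records both branches, so the traced graph still honors
    # a different threshold at runtime.
    return torch.where(threshold > 0.5, x * 2, x)


x = torch.tensor([1.0, 2.0])
traced = torch.jit.trace(gate_traceable, (x, torch.tensor(0.9)))
# traced(x, torch.tensor(0.1)) → tensor([1., 2.])  (branch not frozen)
# traced(x, torch.tensor(0.9)) → tensor([2., 4.])
```

The same reasoning applies to the TracerWarnings visible in the log below: `len(t_span)` and `int(...)` produce Python values, so `tensor.shape[0]` and tensor-level ops are the traceable forms.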

Breaking changes

None

Before submitting

  • Did you make sure title is self-explanatory and the description concisely explains the PR?
  • Did you make sure your PR does only one thing, instead of bundling different changes together?
  • Did you list all the breaking changes introduced by this pull request?
  • Did you test your PR locally with pytest command?
  • Did you run pre-commit hooks with pre-commit run -a command?

@rmcpantoja

Hi @mush42,
The export with vocoder is not working. I'm using torch 2.1.0.dev20230905+cu118, and here's the result:
Command: python3 -m matcha.onnx.export matcha_ljspeech.ckpt ljspeech_with_vocoder.onnx --n-timesteps 5 --vocoder-name "hifigan_T2_v1" --vocoder-checkpoint "generator_v1"
Output:
[🍵] Loading Matcha checkpoint from matcha_ljspeech.ckpt
Setting n_timesteps to 5
[!] Loading matcha_ljspeech!
[+] matcha_ljspeech loaded!
[!] Loading hifigan_T2_v1!
/usr/local/lib/python3.10/dist-packages/torch/nn/utils/weight_norm.py:30: UserWarning: torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.
warnings.warn("torch.nn.utils.weight_norm is deprecated in favor of torch.nn.utils.parametrizations.weight_norm.")
Removing weight norm...
[+] hifigan_T2_v1 loaded!
Embedding the vocoder in the ONNX graph
/content/Matcha-TTS/matcha/models/components/flow_matching.py:77: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
while steps <= len(t_span) - 1:
/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:498: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if current_length != target_length:
/usr/local/lib/python3.10/dist-packages/diffusers/models/attention_processor.py:513: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if attention_mask.shape[0] < batch_size * head_size:
/content/Matcha-TTS/matcha/models/components/decoder.py:149: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert inputs.shape[1] == self.channels
/content/Matcha-TTS/matcha/models/components/flow_matching.py:83: TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
if steps < len(t_span) - 1:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/content/Matcha-TTS/matcha/onnx/export.py", line 181, in
main()
File "/content/Matcha-TTS/matcha/onnx/export.py", line 167, in main
model.to_onnx(
File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/lightning/pytorch/core/module.py", line 1377, in to_onnx
torch.onnx.export(self, input_sample, file_path, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1596, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
outs = ONNXTracedModule(
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 133, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 124, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
result = self.forward(*input, **kwargs)
File "/content/Matcha-TTS/matcha/onnx/export.py", line 30, in forward
wavs = self.vocoder(mel).clamp(-1, 1)
TypeError: 'tuple' object is not callable

@mush42 (Contributor, author) commented Sep 24, 2023

@rmcpantoja Fixed in 2c21a0e
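For context on the failure mode (a hypothetical reconstruction — the loader name and return shape below are assumptions, not the actual Matcha code or the actual fix in 2c21a0e): `'tuple' object is not callable` typically means a loader returning `(model, extras)` was assigned directly to `self.vocoder`, so the call site invoked the tuple instead of the model.

```python
# Hypothetical illustration of the bug pattern, not the real Matcha fix.
def load_vocoder_stub():
    model = lambda mel: mel          # stand-in for the HiFi-GAN generator
    config = {"sample_rate": 22050}  # stand-in for whatever else is returned
    return model, config


vocoder = load_vocoder_stub()
# vocoder(mel)  # → TypeError: 'tuple' object is not callable

vocoder, vocoder_config = load_vocoder_stub()  # unpack before calling
# vocoder(mel) now calls the model as intended
```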

@shivammehta25 (Owner) left a comment

I guess it would have been useful to have a test suite; I will work on that. Also, could you retarget the PR to the dev branch?

Review threads:
  • matcha/utils/model.py (outdated, resolved)
  • matcha/onnx/export.py (resolved)
  • matcha/onnx/infer.py (resolved)
@mush42 mush42 changed the base branch from main to dev September 26, 2023 12:54
…) since the former is dedicated to this use case.
@shivammehta25 (Owner) commented
I am sorry, I am a bit delayed with the review, I will do it by the end of this week. Apologies for the delay.

@mush42 (Contributor, author) commented Sep 27, 2023

No problem, take your time.

Best

@shivammehta25 shivammehta25 added the enhancement New feature or request label Sep 29, 2023
@shivammehta25 (Owner) left a comment
Awesome! Thank you very much for this.

@shivammehta25 shivammehta25 merged commit 51ea36d into shivammehta25:dev Sep 29, 2023
@mush42 (Contributor, author) commented Sep 30, 2023

Glad to be of help. Let's take a swig of that tasty Matcha 🙂

Labels: enhancement (New feature or request)

Successfully merging this pull request may close these issues: ONNX Export

3 participants