
Emo model onnx improvements #19

Open · wants to merge 3 commits into main

Conversation

@leoromanovich commented Dec 20, 2024

Hi, @georgygospodinov!

I've tried to convert the Emo model to ONNX and run it, but ran into the following problems, which I've fixed in this PR:

  • the 1d avg pool + linear layer pair is exported as a strange ONNX subgraph, so inference on that graph fails with an AvgPool -> Gemm shape mismatch (attached pics)
  • there is also no ONNX inference path for the Emo model
[screenshot: ONNX graph showing the AvgPool -> Gemm shape mismatch]

To address these problems, I'm proposing the following changes:

  • replace the avg pool whose kernel_size spans the full dimension with a simple torch.mean over that dimension (in this context the operations are equivalent, so the result does not change)
  • add a recognise_emotion method to run inference with the ONNX version of the Emo model
  • add an Emo model inference example to inference_example.ipynb
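The equivalence claimed in the first bullet can be checked in isolation. A minimal NumPy sketch (a hand-rolled pooling helper, not the repository's code): average pooling whose kernel spans the entire last axis produces exactly the per-channel mean over that axis.

```python
import numpy as np

def avg_pool1d(x, kernel_size):
    # non-overlapping 1d average pooling over the last axis (stride == kernel_size)
    b, c, t = x.shape
    x = x[..., : (t // kernel_size) * kernel_size]
    return x.reshape(b, c, -1, kernel_size).mean(axis=-1)

x = np.random.rand(2, 8, 50)                      # (batch, channels, time)
pooled = avg_pool1d(x, kernel_size=x.shape[-1])   # kernel spans the whole time axis
assert np.allclose(pooled[..., 0], x.mean(axis=-1))
```

Since the kernel covers every element of the pooled dimension, both paths compute the same arithmetic mean, which is why swapping AvgPool1d for torch.mean does not change the model's output.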

@georgygospodinov, may I ask you to review my PR?
I'd appreciate any comments, thanks!

Minimal code sample to reproduce graph problem:

import torch
import onnxruntime as rt

import gigaam
from gigaam.onnx_utils import DTYPE, FEAT_IN, SAMPLE_RATE

if __name__ == "__main__":
    opts = rt.SessionOptions()
    opts.intra_op_num_threads = 16
    opts.execution_mode = rt.ExecutionMode.ORT_SEQUENTIAL

    session = rt.InferenceSession(
        "path_to_onnx_checkpoint/v1_emo.onnx",
        providers=["CPUExecutionProvider"],
        sess_options=opts,
    )

    # extract features from the wav file
    preprocessor = gigaam.preprocess.FeatureExtractor(SAMPLE_RATE, FEAT_IN)

    input_signal = gigaam.load_audio("path_to_example_file/example.wav").unsqueeze(0)
    input_signal = preprocessor(
        input_signal,
        torch.tensor([input_signal.shape[-1]]),
    )[0].numpy()

    # bind the features and their length to the session's input names
    enc_inputs = {
        node.name: data
        for (node, data) in zip(
            session.get_inputs(),
            [input_signal.astype(DTYPE), [input_signal.shape[-1]]],
        )
    }

    enc_features = session.run(
        [node.name for node in session.get_outputs()], enc_inputs
    )[0]

    print(enc_features)

Output:

2024-12-20 13:50:00.905064261 [E:onnxruntime:, sequential_executor.cc:514 ExecuteKernel] Non-zero status code returned while running Gemm node. Name:'/head/Gemm' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/gemm_helper.h:17 onnxruntime::GemmHelper::GemmHelper(const onnxruntime::TensorShape&, bool, const onnxruntime::TensorShape&, bool, const onnxruntime::TensorShape&) left.NumDimensions() == 2 || left.NumDimensions() == 1 was false.

Traceback (most recent call last):
  File "/mnt/original_repo/GigaAM/infer_model.py", line 35, in <module>
    enc_features = session.run(
                   ^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 220, in run
    return self._sess.run(output_names, input_feed, run_options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
onnxruntime.capi.onnxruntime_pybind11_state.RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Gemm node. Name:'/head/Gemm' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/gemm_helper.h:17 onnxruntime::GemmHelper::GemmHelper(const onnxruntime::TensorShape&, bool, const onnxruntime::TensorShape&, bool, const onnxruntime::TensorShape&) left.NumDimensions() == 2 || left.NumDimensions() == 1 was false.
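The error message is precise about the cause: ONNX's Gemm operator only accepts a 1-D or 2-D left operand, while the exported AvgPool1d output keeps a trailing length-1 time dimension, so the head receives a 3-D tensor. A minimal NumPy sketch of the shapes involved (the weight dimensions here are illustrative, not the real model's):

```python
import numpy as np

W = np.random.rand(5, 8)          # hypothetical head weight: (num_emotions, channels)
pooled = np.random.rand(2, 8, 1)  # AvgPool1d output keeps a trailing time dim -> 3-D
# Gemm rejects `pooled` because left.NumDimensions() is 3, not 1 or 2.
flat = pooled.squeeze(-1)         # (batch, channels): the shape Gemm expects
logits = flat @ W.T               # (batch, num_emotions)
assert logits.shape == (2, 5)
```

Replacing the pool with torch.mean over the time dimension yields the 2-D (batch, channels) tensor directly, so the exported graph no longer hits this mismatch.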

The operations are equivalent in this context because the original avg pool uses a kernel size equal to the full pooled dimension.
Also refactored the wav loading/preprocessing and encoder inference code to avoid duplication (DRY).
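For context, the new recognise_emotion path can be sketched roughly as follows. This is a hedged illustration only: the label set and the softmax post-processing are assumptions, not taken from the repository.

```python
import numpy as np

EMOTION_LABELS = ["angry", "sad", "neutral", "positive"]  # assumed label set

def softmax(x):
    # numerically stable softmax over the last axis
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def recognise_emotion(logits):
    # map head logits of shape (batch, num_emotions) to a label + probabilities
    probs = softmax(logits)
    return EMOTION_LABELS[int(probs[0].argmax())], probs

label, probs = recognise_emotion(np.array([[0.1, 2.0, 0.3, -1.0]]))
assert label == "sad"
```

The actual method in the PR feeds the ONNX session's output logits through a step like this to return a human-readable emotion.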