[Conv-TasNet] Facing issue in converting Conv-TasNet model #447

Closed
ajithvcoder opened this issue Aug 10, 2023 · 6 comments
Labels
OP:PRelu, OP:ScatterElements, third party (Third-party tool issues)

Comments

@ajithvcoder

ajithvcoder commented Aug 10, 2023

Issue Type

Others

OS

Linux

onnx2tf version number

1.15.7

onnx version number

1.14.0

onnxruntime version number

1.15.1

onnxsim (onnx_simplifier) version number

0.4.33

tensorflow version number

2.13.0

Download URL for ONNX

https://drive.google.com/file/d/189UHTs9OvDiNBc6BiZDG5zde2zSyTe6E/view?usp=sharing

Parameter Replacement JSON

{
    "format_version": 1,
    "operations": [
      {
        "op_name": "/decoder/Reshape",
        "param_target": "inputs",
        "param_name": "concat",
        "post_process_transpose_perm": [0,1,2,3,4]
      }
    ]
  }

Description

  1. Personal development
scatter_ele

Error:

tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__ConcatV2_N_4_device_/job:localhost/replica:0/task:0/device:CPU:0}} ConcatOp : Ranks of all input tensors should match: shape[0] = [256,2,2,10,1] vs. shape[2] = [2,1,256,2,2,10,1] [Op:ConcatV2] name: concat

ERROR: input_onnx_file_path: conv_tasnet.onnx
ERROR: onnx_op_name: /decoder/ScatterElements
  2. I tried using a parameter replacement JSON, but with no luck; I am not sure how to use it to resolve this problem.
    Command that I tried:
onnx2tf -i conv_tasnet.onnx -kat mixtures -cotof -cotoa 1e-4 -prf replace_conv_tasnet.json
  3. I am working on an audio separation pipeline, for which I need TFLite models.
  4. Repo link: https://github.com/kaituoxu/Conv-TasNet

Onnx conversion script:

import onnx
import torch

from conv_tasnet import ConvTasNet


def convertoOnnx():
    device = torch.device('cpu')
    # Create model.
    model = ConvTasNet(256, 20, 256, 512, 3, 8, 4,
                       2, norm_type="gLN", causal=0,
                       mask_nonlinear="relu")

    model.to(device)
    dummy_input =  {'mixtures': torch.ones(256, 20).to(torch.device('cpu'))}

    onnx_model_path = 'conv_tasnet.onnx'
    torch.onnx.export(model, dummy_input["mixtures"], onnx_model_path, verbose=True, opset_version=12)


def main():
    convertoOnnx()


if __name__ == "__main__":
    main()
@PINTO0309
Owner

PINTO0309 commented Aug 10, 2023

It must be a bug in PyTorch. To begin with, you cannot even run inference on this model with the standard onnxruntime. Have you checked?

sit4onnx -if conv_tasnet.onnx -oep cpu

INFO: file: conv_tasnet.onnx
INFO: providers: ['CPUExecutionProvider']
INFO: input_name.1: onnx::Unsqueeze_0 shape: [256, 20] dtype: float32
Traceback (most recent call last):
  File "/home/b920405/.local/bin/sit4onnx", line 8, in <module>
    sys.exit(main())
  File "/home/b920405/.local/lib/python3.10/site-packages/sit4onnx/onnx_inference_test.py", line 506, in main
    final_results = inference(
  File "/home/b920405/.local/lib/python3.10/site-packages/sit4onnx/onnx_inference_test.py", line 357, in inference
    results = onnx_session.run(
  File "/home/b920405/.local/lib/python3.10/site-packages/onnxruntime/capi/onnxruntime_inference_collection.py", line 217, in run
    return self._sess.run(output_names, input_feed, run_options)
onnxruntime.capi.onnxruntime_pybind11_state.InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Non-zero status code returned while running ScatterElements node. Name: '/decoder/ScatterElements' Status Message: Indices and updates must have the same rank

The exported model embeds indices that violate the ONNX operator specification for ScatterElements (indices and updates must have the same rank). There is no obligation to convert ONNX files that ignore the specification.
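For readers unfamiliar with the rule behind that error, here is a minimal NumPy sketch of ScatterElements without a reduction mode. It is an illustration written for this thread, not onnxruntime's implementation; the function name `scatter_elements` and the sample tensors are assumptions. It rejects inputs whose ranks differ, which is the condition the exported model violates.

```python
import numpy as np


def scatter_elements(data, indices, updates, axis=0):
    """NumPy sketch of ONNX ScatterElements (no reduction)."""
    data = np.asarray(data)
    indices = np.asarray(indices)
    updates = np.asarray(updates)
    # The ONNX spec requires data, indices and updates to share one rank,
    # and indices and updates to have identical shapes.
    if not (data.ndim == indices.ndim == updates.ndim):
        raise ValueError(
            f"rank mismatch: data={data.ndim}, "
            f"indices={indices.ndim}, updates={updates.ndim}"
        )
    if indices.shape != updates.shape:
        raise ValueError("indices and updates must have the same shape")
    out = data.copy()
    # Visit every position of `indices`; the stored index value replaces
    # the coordinate along `axis` to pick the destination element.
    for pos in np.ndindex(indices.shape):
        target = list(pos)
        target[axis] = int(indices[pos])
        out[tuple(target)] = updates[pos]
    return out


data = np.zeros((3, 3), dtype=np.float32)
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)
result = scatter_elements(data, indices, updates, axis=0)

# Adding an extra leading axis to `updates` reproduces the kind of
# rank mismatch reported for /decoder/ScatterElements.
try:
    scatter_elements(data, indices, updates[np.newaxis], axis=0)
    rank_error = None
except ValueError as exc:
    rank_error = str(exc)
```

A model whose ScatterElements inputs fail this check has to be fixed on the export side before any converter can be expected to handle it.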

@ajithvcoder
Author

Thanks for the info.

@ajithvcoder
Author

ajithvcoder commented Aug 17, 2023

Subject: NaN values in TFLite output
@PINTO0309 I somehow managed to remove the ScatterElements node from the model, so it now runs properly in the standard ONNX runtime. Below is the code I used to verify this; the verification succeeds.

import onnx
import torch
import torch.nn as nn
import torch.nn.init as init
from conv_tasnet import ConvTasNet
import onnxruntime
import numpy as np


def convertoOnnx():

    device = torch.device('cpu')

    
    # Create the Conv-TasNet model using the imported model definition.
    model = ConvTasNet(256, 20, 256, 512, 3, 8, 4,
                       2, norm_type="gLN", causal=0,
                       mask_nonlinear="softmax")
    model.eval()
    model.to(device)
    dummy_input = torch.ones(256, 20).to(torch.device('cpu'))

    # Export the model
    torch.onnx.export(model,               # model being run
                    dummy_input,                         # model input (or a tuple for multiple inputs)
                    "conv_tasnet_39nx_7.onnx",   # where to save the model (can be a file or file-like object)
                    export_params=True,        # store the trained parameter weights inside the model file
                    opset_version=12,          # the ONNX version to export the model to
                    do_constant_folding=True,  # whether to execute constant folding for optimization
                    input_names = ['input'],   # the model's input names
                    output_names = ['output'], # the model's output names
                    #dynamic_axes={'input' : {0 : 'batch_size'},    # variable length axes
                    #                'output' : {0 : 'batch_size'}}
                                    )
    #from conv_tasnet import ConvTasNet
    ort_session = onnxruntime.InferenceSession("conv_tasnet_39nx_7.onnx")

    def to_numpy(tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

    x = torch.randn(256, 20).to(device)

    # Run the PyTorch model without tracking gradients to get the reference output.
    with torch.no_grad():
        torch_out = model(x)

    # compute ONNX Runtime output prediction
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
    ort_outs = ort_session.run(None, ort_inputs)

    # compare ONNX Runtime and PyTorch results
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
    print("torch out")
    print(to_numpy(torch_out))
    print("onnx out")
    print(ort_outs[0])
    print("Exported model has been tested with ONNXRuntime, and the result looks good!")

def main():

    convertoOnnx()

if __name__ == "__main__":

    main()

I use the ONNX model below for conversion with onnx2tf; the link also contains the TFLite model produced by the conversion:
https://drive.google.com/file/d/1vzo9P5lEaLLNBra4mKVhYe_5mCREEE6J/view?usp=sharing

This is the command I used for the conversion, and it completed without errors:
onnx2tf -i conv_tasnet_39nx_7.onnx -kat mixtures -cotof -cotoa 1e-4
However, when I run inference with the TFLite model, I get "nan" values in the output matrix of 250,2,20.

import tensorflow as tf
import time


# Load TFLite model
interpreter = tf.lite.Interpreter(model_path="./saved_model/conv_tasnet_39nx_7_float32.tflite")
interpreter.allocate_tensors()
tensor_shape = (256, 20)
input_data = {'waveform': tf.ones(tensor_shape, dtype=tf.float32) }

# Load and preprocess
input_details = interpreter.get_input_details()
input_shape = input_details[0]['shape']
print(input_shape)

# Run inference
interpreter.set_tensor(input_details[0]['index'], input_data["waveform"])
separate_time = time.time()
interpreter.invoke()
print("Done! {:.3f} s".format(time.time() - separate_time))
output_details = interpreter.get_output_details()
output_data = interpreter.get_tensor(output_details[0]['index'])

# Check the device used for inference
# output_tensor = interpreter.tensor(output_data)
# print("Inference was performed on:", output_details.device_name)
output_data = []
for output_detail in output_details:
    output_data.append(interpreter.get_tensor(output_detail['index']))

print(output_data)

@PINTO0309
Owner

PINTO0309 commented Aug 17, 2023

Do as much research on your own as you can. I am not a handyman.

import tensorflow as tf
import time
import numpy as np
np.random.seed(0)

# Load TFLite model
interpreter = tf.lite.Interpreter(model_path="./saved_model/conv_tasnet_39nx_7_float32.tflite")
interpreter.allocate_tensors()
tensor_shape = (256, 20)
input_data = {'waveform': np.random.randn(*tensor_shape).astype(np.float32)}

# Load and preprocess
input_details = interpreter.get_input_details()
input_shape = input_details[0]['shape']
print(input_shape)

# Run inference
interpreter.set_tensor(input_details[0]['index'], input_data["waveform"])
separate_time = time.time()
interpreter.invoke()
print("Done! {:.3f} s".format(time.time() - separate_time))
output_details = interpreter.get_output_details()
output_data = interpreter.get_tensor(output_details[0]['index'])

# Check the device used for inference
# output_tensor = interpreter.tensor(output_data)
# print("Inference was performed on:", output_details.device_name)
output_data = []
for output_detail in output_details:
    output_data.append(interpreter.get_tensor(output_detail['index']))

print(output_data)
onnx2tf \
-i conv_tasnet_39nx_7.onnx \
-cotof \
-onimc "/separator/network/network.2/network.2.0/network.2.0.0/net/net.1/PRelu_output_0"
[array([[[-0.00417661,  0.13557428,  0.00747245, ..., -0.        ,
          0.08287186,         nan]],

       [[-0.01359006,  0.14600608,  0.03704544, ..., -0.        ,
          0.13754255,         nan]],

       [[-0.01396633,         nan, -0.        , ..., -0.        ,
         -0.        ,         nan]],

       ...,

       [[-0.02858675,  0.03051723, -0.        , ...,  0.00328746,
          0.00181477,         nan]],

       [[-0.00639979,  0.17624335,  0.02880344, ..., -0.        ,
          0.02940094,         nan]],

       [[-0.0115342 ,  0.07608917,  0.0747497 , ..., -0.        ,
          0.08765636,         nan]]], dtype=float32)]
onnx2tf \
-i conv_tasnet_39nx_7.onnx \
-cotof \
-onimc "/separator/network/network.1/Conv_output_0"
[array([[[-0.17437564, -0.1193818 , -0.08993477, ...,  0.00564412,
          0.22582152,  0.171181  ]],

       [[-0.01511353, -0.17419885, -0.20137453, ...,  0.00773175,
          0.11049376,  0.11550873]],

       [[ 0.1850396 , -0.11355587, -0.19794238, ..., -0.09272984,
          0.05970074,  0.22245766]],

       ...,

       [[ 0.11041166, -0.0618556 , -0.08706705, ...,  0.04628651,
          0.0985143 ,  0.25410685]],

       [[ 0.12399255, -0.12580661, -0.09162924, ...,  0.13174479,
          0.09035407,  0.1888331 ]],

       [[ 0.13371539,  0.05764169, -0.09615928, ..., -0.02483038,
          0.0715657 ,  0.17773572]]], dtype=float32)]
onnx2tf \
-i conv_tasnet_39nx_7.onnx \
-cotof \
-onimc "/separator/network/network.2/network.2.0/network.2.0.0/net/net.0/Conv_output_0"
[array([[[-0.01670645,  0.13557428,  0.00747245, ..., -0.14510942,
          0.08287186, -0.13191386]],

       [[-0.05436026,  0.14600608,  0.03704544, ..., -0.01873946,
          0.13754255, -0.18560407]],

       [[-0.05586533, -0.08026854, -0.14332882, ..., -0.05465879,
         -0.0857949 , -0.17385857]],

       ...,

       [[-0.11434701,  0.03051723, -0.00572547, ...,  0.00328746,
          0.00181477, -0.01100029]],

       [[-0.02559917,  0.17624335,  0.02880344, ..., -0.01511031,
          0.02940094, -0.10401183]],

       [[-0.04613681,  0.07608917,  0.0747497 , ..., -0.09835263,
          0.08765636, -0.06544452]]], dtype=float32)]
onnx2tf \
-i conv_tasnet_39nx_7.onnx \
-cotof \
-onimc "/separator/network/network.2/network.2.0/network.2.0.0/net/net.1/PRelu_output_0" \
-rtpo PReLU
[array([[[-0.00417661,  0.13557428,  0.00747245, ..., -0.03627735,
          0.08287186, -0.03297846]],

       [[-0.01359006,  0.14600608,  0.03704544, ..., -0.00468486,
          0.13754255, -0.04640102]],

       [[-0.01396633, -0.02006713, -0.0358322 , ..., -0.0136647 ,
         -0.02144873, -0.04346464]],

       ...,

       [[-0.02858675,  0.03051723, -0.00143137, ...,  0.00328746,
          0.00181477, -0.00275007]],

       [[-0.00639979,  0.17624335,  0.02880344, ..., -0.00377758,
          0.02940094, -0.02600296]],

       [[-0.0115342 ,  0.07608917,  0.0747497 , ..., -0.02458816,
          0.08765636, -0.01636113]]], dtype=float32)]
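The runs above cut the model at successive intermediate outputs with -onimc and look for the first one that contains NaN. That search can be sketched in a few lines of Python. The helper name `first_nan_stage` and the short arrays are hypothetical stand-ins for the real intermediate outputs, and the stage names are abbreviated.

```python
import numpy as np


def first_nan_stage(stages):
    """Return the name of the first stage whose output contains NaN, or None."""
    for name, output in stages:
        if np.isnan(np.asarray(output)).any():
            return name
    return None


# Hypothetical stand-ins for the intermediate outputs, listed in graph order.
stages = [
    ("network.1/Conv_output_0", np.array([-0.17, 0.22, 0.17], dtype=np.float32)),
    ("net.0/Conv_output_0", np.array([-0.016, 0.083, -0.13], dtype=np.float32)),
    ("net.1/PRelu_output_0", np.array([-0.004, 0.083, np.nan], dtype=np.float32)),
]
culprit = first_nan_stage(stages)
print(culprit)
```

With the real outputs, both Conv outputs are clean and the first NaN shows up at the PRelu output, which is what points at the PReLU op.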

There must be a bug in the PReLU of the TFLite runtime. I do not inspect the source code of every TFLite runtime, so I cannot say where. Alternatively, it is quite possible that this behavior is intentional.
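As a rough illustration of why replacing the op helps, PReLU can be rewritten with elementary operations only, so the runtime's dedicated PReLU kernel is never invoked. The NumPy sketch below shows one such rewrite; whether -rtpo PReLU emits exactly this decomposition is an assumption here, and the function names are illustrative. The slope of 0.25 is inferred from the outputs in this thread (for example, -0.01670645 before PReLU becomes -0.00417661 after it).

```python
import numpy as np


def prelu(x, slope):
    """Textbook PReLU: x where x >= 0, slope * x elsewhere."""
    return np.where(x >= 0, x, slope * x)


def pseudo_prelu(x, slope):
    """The same function built only from ReLU, Abs, Sub, Mul and Add."""
    positive = np.maximum(x, 0.0)              # ReLU keeps the positive part
    negative = (x - np.abs(x)) * 0.5 * slope   # (x - |x|) / 2 equals min(x, 0)
    return positive + negative


rng = np.random.default_rng(0)
x = rng.standard_normal((4, 8)).astype(np.float32)
slope = np.float32(0.25)  # assumed slope, consistent with the outputs above
max_abs_diff = float(np.max(np.abs(prelu(x, slope) - pseudo_prelu(x, slope))))
```

Both forms agree up to floating-point rounding, so the swap should not change the model's results.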

onnx2tf -i conv_tasnet_39nx_7.onnx -cotof -rtpo PReLU
[array([[[-0.02121598, -0.01494515,  0.03767578, ...,  0.04305519,
         -0.03268029, -0.00984246],
        [-0.01842985, -0.02896241,  0.00032737, ...,  0.03822319,
         -0.03688553, -0.039774  ]],

       [[-0.09389048,  0.02472125,  0.06904005, ..., -0.04805087,
         -0.05356564,  0.0098032 ],
        [-0.0654762 ,  0.06695021,  0.06765427, ..., -0.08847025,
         -0.03741876, -0.02165272]],

       [[-0.10159367, -0.04728557, -0.00644698, ..., -0.00614144,
         -0.00697455, -0.02969793],
        [-0.10006553, -0.03004916, -0.04499952, ..., -0.0368793 ,
          0.01113988, -0.05985563]],

       ...,

       [[-0.03429771, -0.10090043, -0.0561833 , ..., -0.03551175,
          0.02616569, -0.00240966],
        [-0.01021105, -0.08929865, -0.06661414, ..., -0.03666518,
          0.03669975, -0.01670682]],

       [[-0.02011088, -0.02259451,  0.0043005 , ...,  0.01583091,
         -0.01501969,  0.03965813],
        [-0.03771172, -0.02292055, -0.02259337, ...,  0.00371877,
          0.00514893,  0.00786439]],

       [[-0.04744648,  0.04065312, -0.00950121, ..., -0.0026527 ,
          0.00029717,  0.00457481],
        [-0.02485752,  0.02383837, -0.00040416, ...,  0.0262814 ,
         -0.00458154, -0.00789627]]], dtype=float32)]

image

prelu_check.onnx.zip

@PINTO0309
Owner

I decided to add to the README the numerous workarounds I have implemented in this tool, as they do not seem to be understood by most engineers.

@ajithvcoder
Author

thanks
