Concat argument element types are inconsistent #4745

Closed · addisonklinke opened this issue Mar 11, 2021 · 19 comments
Labels: bug (Something isn't working), PSE, support_request

addisonklinke commented Mar 11, 2021

System information (version)
  • OpenVINO: 2021.2.185
  • Operating System / Platform: Ubuntu 18.04
  • Compiler: NA
  • Problem classification: Model inference
  • Framework: ONNX
Detailed description

The model optimizer successfully converts my ONNX model to the OpenVINO IR representation (.xml and .bin) - there are no errors or warnings. However, the Python binding for the inference engine fails to even load the network. Since the IR files are produced without issue, I expect the network to at least load without error.

Steps to reproduce

UPDATE March 16: I added code in a later comment that generates a very basic (5-node) ONNX graph reproducing the error

  1. Start with a fresh, official Ubuntu 18.04 Docker container
  2. Follow the official OpenVINO installation instructions - I executed setupvars.sh, installed ONNX prerequisites, and successfully ran the SqueezeNet demo
  3. Run the model optimizer command: python3 mo.py --input_model model.onnx --progress
  4. Attempt to load the model in Python
from openvino.inference_engine import IECore, IENetwork

ie = IECore()
net = ie.read_network(model='model.xml', weights='model.bin')

The last line fails with

Traceback (most recent call last):
  File "minimal.py", line 4, in <module>
    net = ie.read_network(model='model.xml', weights='model.bin')
  File "ie_api.pyx", line 261, in openvino.inference_engine.ie_api.IECore.read_network
  File "ie_api.pyx", line 285, in openvino.inference_engine.ie_api.IECore.read_network
RuntimeError: Check 'element::Type::merge(inputs_et, inputs_et, get_input_element_type(i))' failed at ngraph/core/src/op/concat.cpp:60:
While validating node 'v0::Concat Concat_470 (413[0]:i64{?,1}, 469/Unsqueeze[0]:i32{?,1}) -> (dynamic?)' with friendly_name 'Concat_470':
Argument element types are inconsistent.

The traceback indicates that the two input nodes of Concat_470 have inconsistent dtypes int32 and int64. However, if I open model.xml with Netron, this is not true. The outputs of both nodes 469/Unsqueeze and 413[0] (shown below) are consistently int64 as expected. The dtype agreement is also correct in the original ONNX graph. It appears that only the OpenVINO inference engine is somehow misinterpreting the dtypes.
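For anyone who wants to double-check the element types programmatically rather than in Netron, here is a minimal sketch using the onnx Python package (the file name model.onnx is a placeholder for the exported model):

import onnx

# Load the exported graph and run shape/type inference so intermediate
# tensors (such as the Unsqueeze output feeding the Concat) get a value_info entry
model = onnx.load('model.onnx')
inferred = onnx.shape_inference.infer_shapes(model)

# Print the element type of every known tensor; the tensors feeding the
# Concat should all report INT64 if the ONNX graph is internally consistent
tensors = list(inferred.graph.input) + list(inferred.graph.value_info) + list(inferred.graph.output)
for vi in tensors:
    elem_type = onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type)
    print(vi.name, elem_type)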

The only thing odd in the XML graph is that the concat node causing the error lists its input types as ?. Note that input 43:2 comes from 413/Multiply[0] and input 574:2 comes from 469/Unsqueeze.

This leads me to believe that the inference engine is inferring the int32 dtype that causes the error rather than knowing it for certain. If that is the root issue, I am unclear why:

  1. The concat node doesn't know its input dtypes explicitly. Both inputs are defined as int64 in the preceding nodes, so there should be no question what the concat dtypes should be when the XML graph is constructed
  2. OpenVINO generates an XML IR with unknown (?) dtypes instead of raising an error on the conversion from ONNX
  3. OpenVINO defaults to int32 when inferring the dtype
Issue submission checklist
  • I report the issue, it's not a question
    • Yes, OpenVINO should definitely not crash when trying to load a model that it generated
  • I checked the problem with documentation, FAQ, open issues, Stack Overflow, etc. and have not found a solution
  • There is reproducer code and related data files: images, videos, models, etc.
    • See the comment below for complete code + commands describing a PyTorch > ONNX > OpenVINO conversion pipeline that reproduces the bug
addisonklinke added the bug (Something isn't working) and support_request labels on Mar 11, 2021
Iffa-Intel assigned and then unassigned Iffa-Intel on Mar 12, 2021
Iffa-Intel self-assigned this on Mar 15, 2021
Iffa-Intel added the ONNX (Related to support for ONNX standard) label and removed the bug (Something isn't working) label on Mar 15, 2021
lazarevevgeny (Contributor) commented

@mvafin, please take a look.

addisonklinke (Author) commented Mar 15, 2021

@Iffa-Meah @lazarevevgeny @mvafin I have created a minimal PyTorch/ONNX model that will allow you to reproduce the bug without relying on my more complex custom model. Please see the attached ONNX file (or use the script below to generate it)

import numpy as np
import onnx
import onnxruntime
import torch
import torch.onnx
import torch.nn as nn


def to_numpy(tensor):
    """Helper function recommended in official PyTorch ONNX tutorial

    See details below
    https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
    """
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


class ConcatModel(nn.Module):

    def __init__(self, vocab_size=10, seq_len=3):
        """Setup random character probabilities

        Do this only once in the init so that calling the model multiple
        times (whether from PyTorch or ONNX) will produce identical results.

        :param int vocab_size: Number of unique character possibilities
        :param int seq_len: Maximum number of characters in the sequence
        """
        super(ConcatModel, self).__init__()
        self.seq_len = seq_len
        self.char_probs = [torch.randn(vocab_size) for _ in range(seq_len)]

    def forward(self, sos_token):
        """Simplified generation for sequence of character IDs

        :param torch.Tensor sos_token: Index representing start-of-sequence
            token. Must be a tensor because JIT trace for the ONNX export
            doesn't support int inputs. Expected shape ``(1, )``
        :return torch.LongTensor seq: Complete sequence of character indices
        """
        seq = sos_token
        for i in range(self.seq_len):
            new_idx = torch.argmax(self.char_probs[i]).unsqueeze(0)
            seq = torch.cat([seq, new_idx], dim=0)
        return seq


if __name__ == '__main__':

    # Set constants
    torch.manual_seed(1234)
    sos_token = torch.tensor([-1])
    model_path = 'concat_model.onnx'

    # Run PyTorch inference and export to ONNX
    model = ConcatModel()
    model.eval()
    torch_out = model(sos_token)
    torch.onnx.export(
        model=model,
        args=sos_token,
        f=model_path,
        opset_version=9,
        do_constant_folding=True,
        input_names=['sos_token'],
        output_names=['seq'])

    # Load/check the ONNX model and run through Python API
    onnx_model = onnx.load(model_path)
    onnx.checker.check_model(onnx_model)
    ort_session = onnxruntime.InferenceSession(model_path)
    ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(sos_token)}
    ort_outs = ort_session.run(None, ort_inputs)
    np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

After downloading or producing the ONNX file, convert it to the OpenVINO IR

python3 mo.py --input_model concat_model.onnx

And reproduce the error when trying to load into the inference engine

from openvino.inference_engine import IECore, IENetwork

ie = IECore()
net = ie.read_network(model='concat_model.xml', weights='concat_model.bin')

The error traceback is very similar to the one from my custom model

Traceback (most recent call last):
  File "minimal.py", line 6, in <module>
    net = ie.read_network(model='concat_model.xml', weights='concat_model.bin')
  File "ie_api.pyx", line 261, in openvino.inference_engine.ie_api.IECore.read_network
  File "ie_api.pyx", line 285, in openvino.inference_engine.ie_api.IECore.read_network
RuntimeError: Check 'element::Type::merge(inputs_et, inputs_et, get_input_element_type(i))' failed at ngraph/core/src/op/concat.cpp:60:
While validating node 'v0::Concat Concat_1 (sos_token[0]:i32{1}, Constant_0/Output_0/Data__const[0]:i64{1}) -> ()' with friendly_name 'Concat_1':
Argument element types are inconsistent.

One of the two concat inputs is int32 while the other is int64. In the ONNX graph everything is int64 throughout, but in the XML sos_token starts out as int32. I noticed the model optimizer's --input flag can be used to specify the dtype, so I changed my command to

python3 mo.py --input_model concat_model.onnx --input sos_token[1]{i64}

With this modification, the inference engine is able to load the resulting XML without any error (as desired). However, I still believe this is a bug because the ONNX model explicitly defines sos_token as int64 and OpenVINO changes this without any notice. Also, with my real model the input node to the concat is not a top-level input to the model, so I cannot provide its dtype to the model optimizer on the command line. Is there a patch I could add to the model optimizer source code that would allow my model's dtypes to be recognized properly?
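As a sanity check for the --input override, the precision that actually ended up in the generated IR can be inspected before calling load_network; a small sketch assuming the 2021.x Python inference engine bindings:

from openvino.inference_engine import IECore

ie = IECore()
net = ie.read_network(model='concat_model.xml', weights='concat_model.bin')

# Report the precision the IR assigned to each network input;
# with --input sos_token[1]{i64} this should show I64, otherwise I32
for name, info in net.input_info.items():
    print(name, info.precision)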

Iffa-Intel commented

I believe this "Argument element types are inconsistent" error happens because you didn't explicitly specify your scale values and mean values, particularly during the ONNX-to-IR conversion.

You may take this as an example. In the model optimizer, the command to convert that model should look like: mo.py -m resnet18.onnx --scale_values=[58.395,57.120,57.375] --mean_values=[123.675,116.28,103.53] --reverse_input_channels --disable_resnet_optimization --disable_fusing --disable_gfusing --data_type=FP32 --output_dir fp32

A simple import and the creation of an ExecutableNetwork object also works for that converted model.
[screenshot: load_model]

addisonklinke (Author) commented

My full model is designed to include the image scaling and mean values inside the ONNX graph, so adding those as CLI flags would duplicate the transformation. I tested your suggestion anyway, and it does not make a difference for my full model - the error is still the same.

In my full model, the concat nodes occur many operations into the graph, which is why they are not affected by providing the mean/scale value options. We need to understand why the unsqueeze node (which is explicitly int64 in the ONNX graph) gets converted to int32 by OpenVINO. My guess is that the model optimizer code has a hard-coded preference for int32 somewhere which overrides the int64 definition provided by ONNX.

addisonklinke (Author) commented Mar 16, 2021

To further troubleshoot my full model, I added an explicit cast layer in the ONNX graph prior to the troublesome concat.
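A cast like this can be inserted with onnx.helper; the following is only a rough sketch, and the tensor name unsqueeze_out is a hypothetical stand-in for the actual Concat input in the real graph:

import onnx
from onnx import helper, TensorProto

model = onnx.load('model.onnx')
graph = model.graph

# Hypothetical name of the tensor feeding the problematic Concat
src = 'unsqueeze_out'
cast_out = src + '_i64'
cast_node = helper.make_node('Cast', inputs=[src], outputs=[cast_out], to=TensorProto.INT64)

for idx, node in enumerate(graph.node):
    if node.op_type == 'Concat' and src in node.input:
        # Rewire this Concat input to consume the cast output instead
        for i, name in enumerate(node.input):
            if name == src:
                node.input[i] = cast_out
        # Insert the Cast right before the Concat to keep the node list topologically sorted
        graph.node.insert(idx, cast_node)
        break

onnx.checker.check_model(model)
onnx.save(model, 'model_cast.onnx')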

I expected this to resolve any uncertainty about the incoming concat dtype; however, OpenVINO's model optimizer removes my cast node and continues to leave the concat inputs' dtypes as ?

The error from the inference engine is still the same (I believe the node names only changed because of my ONNX modifications to add a cast node)

Traceback (most recent call last):
  File "minimal.py", line 10, in <module>
    net = ie.read_network(model=f'{args.model}.xml', weights=f'{args.model}.bin')
  File "ie_api.pyx", line 261, in openvino.inference_engine.ie_api.IECore.read_network
  File "ie_api.pyx", line 285, in openvino.inference_engine.ie_api.IECore.read_network
RuntimeError: Check 'element::Type::merge(inputs_et, inputs_et, get_input_element_type(i))' failed at ngraph/core/src/op/concat.cpp:60:
While validating node 'v0::Concat Concat_2847 (Mul_104[0]:i64{?,1}, Unsqueeze_154/Unsqueeze[0]:i32{?,1}) -> (dynamic?)' with friendly_name 'Concat_2847':
Argument element types are inconsistent.

Can anyone explain why my cast operation is being ignored? Regardless of what I try, OpenVINO is not respecting the dtypes that are specified (very explicitly) in my ONNX graph

lazarevevgeny (Contributor) commented

@addisonklinke , not all OpenVINO plugins natively support int64, so when we have operations producing int64 values on a data path (not ShapeOf sub-graphs, which are const-folded before passing the model to the plugin), the MO converts them to int32. This is why the explicit cast to int64 is removed from the model.

But when you explicitly specify the input type using the "--input sos_token[1]{i64}" parameter, you override the default behaviour of the MO and the IR is generated with a Parameter of type int64. However, this IR will not work for some OpenVINO plugins because they don't support int64 natively.

addisonklinke (Author) commented

Thank you for that clarification. Is there a way I can override the default behavior for non-input nodes and generate an IR with type int64 (even if it involves modifying the MO source code)? The target hardware for this model is the Movidius VPU - if int64 plugins are not supported by that device, would the inference engine be able to fall back to CPU execution for certain nodes?

In terms of precision, I have no problem running all the model nodes on int32. The issue is that PyTorch's nn.Embedding layer does not support int32, so I am forced to use int64. Many users have requested int32 support for that layer on this issue thread, but unfortunately the PyTorch team has not yet implemented that feature.
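To illustrate the constraint, a tiny sketch of the Embedding index dtype behaviour (on the PyTorch versions available at the time of this comment, the int32 call is expected to raise; per the follow-up below, support landed in 1.8.0):

import torch
import torch.nn as nn

emb = nn.Embedding(num_embeddings=10, embedding_dim=4)

# int64 (Long) indices are always accepted
idx64 = torch.tensor([1, 2, 3], dtype=torch.int64)
print(emb(idx64).shape)  # torch.Size([3, 4])

# int32 indices were rejected by older PyTorch releases
idx32 = torch.tensor([1, 2, 3], dtype=torch.int32)
try:
    print(emb(idx32).shape)
except RuntimeError as err:
    print('int32 indices not supported:', err)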

lazarevevgeny (Contributor) commented

Looks like you need to comment out the following lines: https://github.com/openvinotoolkit/openvino/blob/master/model-optimizer/extensions/front/ChangePlaceholderTypes.py#L52-L55

Unfortunately, I cannot answer whether the fallback to CPU will work in this case.

addisonklinke (Author) commented Mar 18, 2021

Shortly after my comment yesterday, PyTorch clarified that the int32 issue has been resolved on their end in the latest 1.8.0 release. By upgrading PyTorch and tweaking my model, I was able to produce an IR that is loaded by the OpenVino inference engine without any issue on CPU. After seeing this issue play out, I believe the previous error when using int64 should be raised by the model optimizer rather than the inference engine. From my perspective, there is no point in producing an IR that is known to fail with the inference engine, so raising the error earlier in the process would be more intuitive for end users. Of course, maybe the OpenVino developers disagree with me, but this is my suggestion 🙂

As a final step, I am working to load and run the full model on the Movidius VPU, but the same ie.load_network call produces a new error, which I notice comes from the source code in model.cpp

RuntimeError: duplicateData error: while duplicating Conv_110/reshape_begin 
Const data got different desc and content byte sizes (1500 and 300 respectively)

If I add ie.set_config({'VPU_HW_STAGES_OPTIMIZATION': 'NO'}, 'MYRIAD') the error goes away, but performance is prohibitively slow without the optimizations. Are you able to provide some insight into the duplicateData error? I have identified the corresponding node in my graph, but do not understand the meaning of the desc and content variables as used in the source code. Looking at mo/front/common/custom_replacement_registry.py:55, I think desc is probably "description" - so is the issue that the dimension size of the data times its dtype size does not equal the actual content size?
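For completeness, a sketch of where that workaround call sits in the loading flow (model paths are placeholders, and this only trades the crash for the slow non-optimized path):

from openvino.inference_engine import IECore

ie = IECore()
net = ie.read_network(model='model.xml', weights='model.bin')

# Workaround: disable the VPU hardware-stage optimizations before load_network.
# The duplicateData error goes away, but inference becomes prohibitively slow.
ie.set_config({'VPU_HW_STAGES_OPTIMIZATION': 'NO'}, 'MYRIAD')
exec_net = ie.load_network(network=net, num_requests=1, device_name='MYRIAD')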

addisonklinke changed the title from "[Bug] Concat argument element types are inconsistent" to "Concat argument element types are inconsistent" on Mar 18, 2021
lazarevevgeny added the category: VPU label and removed the ONNX (Related to support for ONNX standard) label on Mar 22, 2021
lazarevevgeny (Contributor) commented

@addisonklinke , I agree that the MO should not generate the IR in case of a type mismatch, but unfortunately we cannot do this for now.

@taka-no-me , @gladilov-gleb , could you take a look at the VPU issue? The CPU works fine for the model.

addisonklinke (Author) commented

@taka-no-me @gladilov-gleb I have created a minimal reproducible example for the VPU optimization error. It seems to occur whenever the graph uses a constant to initialize a sequence of operations. Please use my Python script below to produce two ONNX variants: implicit.onnx and explicit.onnx

import torch
import torch.onnx
import torch.nn as nn


class Simple(nn.Module):
    def __init__(self, img_size):
        super(Simple, self).__init__()
        self.img_size = img_size
        self.conv = nn.Conv2d(3, 16, kernel_size=3)

    def forward(self, a, b=None):
        if b is None:
            b = torch.ones(1, 3, self.img_size, self.img_size)
        feat_b = self.conv(b)
        feat_a = self.conv(a)
        return feat_a + feat_b



if __name__ == '__main__':

    img_size = 24
    model = Simple(img_size)
    model.eval()

    template = torch.randn(1, 3, img_size, img_size)
    implicit_b = template
    explicit_b = (template, template)
    torch.onnx.export(model, implicit_b, 'implicit.onnx')
    torch.onnx.export(model, explicit_b, 'explicit.onnx')

Then convert both to the OpenVino IR with python3 mo.py --input_model *.onnx. Afterwards, run the test script below on a machine with a Movidius NCS2 plugged in

from argparse import ArgumentParser
from openvino.inference_engine import IECore

parser = ArgumentParser(description='Test VPU inference with OpenVino IR files')
parser.add_argument('-m', '--model', type=str, required=True, help='common path prefix of IR files')
args = parser.parse_args()

print(f'Loading {args.model}')
model_xml = args.model + '.xml'
model_bin = args.model + '.bin'

ie = IECore()
net = ie.read_network(model=model_xml, weights=model_bin, init_from_buffer=False)
print('Read network')
exec_net = ie.load_network(network=net, num_requests=1, device_name='MYRIAD')
print('Loaded network to VPU')

You should see

$ python3 test.py -m explicit
Loading explicit
Read network
Loaded network to VPU

$ python3 test.py -m implicit
Loading implicit
Read network
Traceback (most recent call last):
  File "minimal.py", line 15, in <module>
    exec_net = ie.load_network(network=net, num_requests=1, device_name='MYRIAD')
  File "ie_api.pyx", line 306, in openvino.inference_engine.ie_api.IECore.load_network
  File "ie_api.pyx", line 315, in openvino.inference_engine.ie_api.IECore.load_network
RuntimeError: duplicateData error: while duplicating Constant_0/Output_0/Data__const 
Const data got different desc and content byte sizes (4608 and 3456 respectively)

addisonklinke (Author) commented

I also tried registering a PyTorch buffer as the default value for b, but that still produces the same error on the implicit model

import torch
import torch.nn as nn


class Simple(nn.Module):
    def __init__(self, img_size):
        super(Simple, self).__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.register_buffer('default_b', torch.ones(1, 3, img_size, img_size))

    def forward(self, a, b=None):
        if b is None:
            b = self.default_b
        feat_b = self.conv(b)
        feat_a = self.conv(a)
        return feat_a + feat_b

ggladilov (Contributor) commented

@Maxim-Doronin please take a look

addisonklinke (Author) commented

@Maxim-Doronin any updates on this?

Maxim-Doronin (Contributor) commented Apr 16, 2021

@addisonklinke, we've added this issue to our sprint. We will look at it as soon as possible.

Ref. 51088

Iffa-Intel added the PSE and bug (Something isn't working) labels on Apr 20, 2021
ilyachur (Contributor) commented

@Maxim-Doronin Do we have any progress for this issue?

addisonklinke (Author) commented Jul 14, 2021

@Maxim-Doronin I see that OpenVino has had two new releases since the start of this issue:

  • 2021.3.0 (March 23, 2021)
  • 2021.4.0 (June 29, 2021)

Does either of them include the bug fix from the sprint you mentioned that would address my VPU issue? If so, I can test and close this issue if everything looks good.

EDIT: I tested in an Ubuntu 18.04 Docker container with model optimizer 2021.4.582 and inference engine 2021.4.0-3839-cd81789d294 (from PyPI), and I am still getting the same error

$ python3 minimal.py -m implicit
Loading implicit
Read network
Traceback (most recent call last):
  File "minimal.py", line 16, in <module>
    exec_net = ie.load_network(network=net, num_requests=1, device_name=args.device)
  File "ie_api.pyx", line 367, in openvino.inference_engine.ie_api.IECore.load_network
  File "ie_api.pyx", line 379, in openvino.inference_engine.ie_api.IECore.load_network
RuntimeError: [ GENERAL_ERROR ] 
/home/jenkins/agent/workspace/private-ci/ie/build-linux-centos76/b/repos/openvino/inference-engine/src/vpu/graph_transformer/src/model/model.cpp:201 
duplicateData error: while duplicating Constant_0 
Const data got different desc and content byte sizes (4608 and 3456 respectively)

jgespino (Contributor) commented

Hi @addisonklinke

Apologies for the delay in our response; the proper fix is not included in these two releases. It is currently planned for a future release, but I cannot comment on the timing. As a workaround, you will need to specify the input as you mentioned above:
--input sos_token[1]{i64}

Regards,
Jesus
Ref. 55342

jgespino (Contributor) commented Nov 8, 2021

The bug has been fixed by this PR: #7630

jgespino closed this as completed on Nov 8, 2021