
[Relay][Frontend][ONNX] operator support NonZero #5073

Merged

Conversation

cchung100m
Contributor

Hi @tqchen @vinx13 @zhiics

Following issue #4568, this PR adds the NonZero operator and a test case for the ONNX frontend. I would appreciate it if you could help review/manage this PR, thanks.

@tqchen
Member

tqchen commented Mar 16, 2020

@kazum @masahi can you please follow up?

@@ -30,7 +30,8 @@
 import scipy


-def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
+def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32',
+                   opset=None, mode=None):
Contributor

Just curious, why does mode need to be added? Does NonZero only work in the VM executor? If so, maybe we can just use the VM exclusively for testing; I'm not sure there's a need to support both.

Contributor Author

Yes, that makes sense; I will put it into another function instead of adding it to get_tvm_output.
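
A minimal sketch of what such a separate VM-based helper could look like (the name get_tvm_output_with_vm matches the helper that appears in the test output further down; get_input_data_shape_dict is assumed here to exist as a shared helper for mapping graph inputs to shapes):

def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
    """Run an ONNX graph through the Relay VM executor and return numpy output."""
    # Assumed shared helper that maps graph inputs to their shapes.
    _, shape_dict = get_input_data_shape_dict(graph_def, input_data)
    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
    ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
    indata = tvm.nd.array(input_data)
    result = ex.evaluate()(indata)
    return result.asnumpy()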

@masahi masahi self-assigned this Mar 17, 2020
def _impl_v9(cls, inputs, attr, params):
    if len(inputs) > 1:
        raise ValueError("Expect 1 input only")
    return AttrCvt(op_name='argwhere')(inputs, attr, params)
Contributor

The output shape of argwhere differs from that of ONNX NonZero: argwhere yields (num_nonzero, ndim) while NonZero expects (ndim, num_nonzero). A transpose is needed here.

Contributor Author

Thanks for the review; I updated it.
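
For reference, a minimal sketch of what the updated converter could look like, with a transpose flipping argwhere's (num_nonzero, ndim) output into ONNX NonZero's (ndim, num_nonzero) layout (the _op alias for tvm.relay.op is an assumption borrowed from TVM frontend conventions):

@classmethod
def _impl_v9(cls, inputs, attr, params):
    if len(inputs) > 1:
        raise ValueError("Expect 1 input only")
    # argwhere yields one row per non-zero element; ONNX NonZero
    # expects one row per input dimension, hence the transpose.
    output = AttrCvt(op_name='argwhere')(inputs, attr, params)
    return _op.transpose(output, axes=(1, 0))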


mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)

ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target=target)
Contributor

tvm.cpu() should be ctx?

Contributor Author

Thanks for the review; I updated it.
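
For reference, the corrected call simply threads the helper's ctx argument through instead of hard-coding the CPU context:

ex = relay.create_executor(mode, mod=mod, ctx=ctx, target=target)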

@cchung100m cchung100m force-pushed the support_operator_NonZero_in_frontend_onnx branch from 6c76ff8 to fdd8733 on March 17, 2020 at 12:08
return input_names, shape_dict


def get_input_data_dtype_dict(graph_def, input_data):
Contributor

This function is not used anywhere.

Contributor Author

Hi @kazum
Thanks for the review; I will remove it here.

return result.asnumpy()


def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
Contributor

How about using the VM for all the tests?

Contributor Author

Hi @kazum
I'd like to keep both relay.create_executor and relay.build in this PR.

I cannot replace relay.build with relay.create_executor directly because doing so produces many errors like the ones below:

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2282, in <module>
    test_flatten()

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 374, in test_flatten
    tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 70, in get_tvm_output
    result = ex.evaluate()(indata)

  File "/tvm/python/tvm/relay/backend/vm.py", line 256, in _vm_wrapper
    return self.vm.run(*args)

  File "/tvm/python/tvm/runtime/vm.py", line 366, in run
    return self.invoke("main", *args, **kwargs)

  File "/tvm/python/tvm/runtime/vm.py", line 348, in invoke
    return self._invoke(func_name)

  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (7) 8   ???                                 0x00007fff54230930 0x0 + 140734604970288
  [bt] (6) 7   _ctypes.cpython-37m-darwin.so       0x00000001104dc2bf ffi_call_unix64 + 79
  [bt] (5) 6   libtvm.dylib                        0x0000000125071f78 TVMFuncCall + 72
  [bt] (4) 5   libtvm.dylib                        0x00000001250a4e3f std::__1::__function::__func<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 735
  [bt] (3) 4   libtvm.dylib                        0x00000001250a199a tvm::runtime::vm::VirtualMachine::RunLoop() + 7610
  [bt] (2) 3   libtvm.dylib                        0x00000001250a310f tvm::runtime::vm::VirtualMachine::InvokePacked(long long, tvm::runtime::PackedFunc const&, long long, long long, std::__1::vector<tvm::runtime::ObjectRef, std::__1::allocator<tvm::runtime::ObjectRef> > const&) + 1039
  [bt] (1) 2   libtvm.dylib                        0x000000012507b396 std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 310
  [bt] (0) 1   libtvm.dylib                        0x0000000124666af9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/tvm/src/runtime/library_module.cc", line 89
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (((tvm_struct_get(arg0, 0, 5) == (uint8)2) && (tvm_struct_get(arg0, 0, 6) == (uint8)32)) && (tvm_struct_get(arg0, 0, 7) == (uint16)1)), arg0.dtype is expected to be float32

File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2128, in verify_lstm
    output_dtype=['float32', 'float32', 'float32'])

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 69, in get_tvm_output
    indata = tvm.nd.array(input_data)

  File "/tvm/python/tvm/runtime/ndarray.py", line 487, in array
    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)

  File "/tvm/python/tvm/runtime/ndarray.py", line 270, in empty
    dtype = DataType(dtype)

  File "/tvm/python/tvm/_ffi/runtime_ctypes.py", line 101, in __init__
    raise ValueError("Do not know how to handle type %s" % type_str)

ValueError: Do not know how to handle type object

Maybe we can address the above issues in a follow-up PR and replace relay.build with relay.create_executor there?
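
A hypothetical illustration of the last error above (the ragged input list is an assumed reproduction, not code from the test suite): multi-input models hand the helper a Python list of arrays, and coercing such a list into a single numpy array yields dtype=object, which tvm.nd.array cannot map to a TVM dtype.

import numpy as np

# Two inputs with different shapes, as a multi-input model would supply.
inputs = [np.zeros((1, 3), dtype='float32'), np.zeros((2,), dtype='float32')]
arr = np.asarray(inputs, dtype=object)
print(arr.dtype)  # object -> "Do not know how to handle type object"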

Contributor

Okay, I'm fine with keeping both for now :)

@cchung100m cchung100m force-pushed the support_operator_NonZero_in_frontend_onnx branch from fdd8733 to c7e6b10 on March 18, 2020 at 15:51
@cchung100m
Contributor Author

Hi @kazum

It seems that I need to add a CUDA schedule for argwhere in this PR. I would appreciate it if you could guide me through it, many thanks. :)


=================================== FAILURES ===================================
_________________________________ test_nonzero _________________________________

    def test_nonzero():

        def verify_nonzero(indata, outdata, dtype):
            node = helper.make_node('NonZero',
                                    inputs=['X'],
                                    outputs=['Y'],)

            graph = helper.make_graph([node],
                                      "nonzero_test",
                                      inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
                                      outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))])

            model = helper.make_model(graph, producer_name='nonzero_test')

            onnx_out = get_onnxruntime_output(model, indata, dtype)

            for target, ctx in ctx_list():
                tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
                tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)

        input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
        result = np.array((np.nonzero(input_data)))  # expected output [[0, 1, 1], [0, 0, 1]]
>       verify_nonzero(input_data, result, dtype=np.int64)

tests/python/frontend/onnx/test_forward.py:2251:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
tests/python/frontend/onnx/test_forward.py:2246: in verify_nonzero
    tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
tests/python/frontend/onnx/test_forward.py:54: in get_tvm_output_with_vm
    ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
python/tvm/relay/build_module.py:414: in create_executor
    return VMExecutor(mod, ctx, target)
python/tvm/relay/backend/vm.py:247: in __init__
    self.executable = compile(mod, target)
python/tvm/relay/backend/vm.py:68: in compile
    compiler.lower(mod, target, target_host)
python/tvm/relay/backend/vm.py:134: in lower
    self._lower(mod, target, target_host)
tvm/_ffi/_cython/./packed_func.pxi:308: in tvm._ffi._cy3.core.PackedFuncBase.__call__
    ???
tvm/_ffi/_cython/./packed_func.pxi:243: in tvm._ffi._cy3.core.FuncCall
    ???
tvm/_ffi/_cython/./packed_func.pxi:232: in tvm._ffi._cy3.core.FuncCall3
    ???
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

>   ???
E   tvm._ffi.base.TVMError: Traceback (most recent call last):
E     [bt] (8) /workspace/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::VisitExpr_(tvm::relay::CallNode const*)+0x3d5) [0x7f8cb0e7ef55]
E     [bt] (7) /workspace/build/libtvm.so(tvm::relay::OpMatch<void>::operator()(tvm::relay::Call const&)+0xef) [0x7f8cb0e7c27f]
E     [bt] (6) /workspace/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::VisitExpr_(tvm::relay::CallNode const*)::{lambda(tvm::Array<tvm::RelayExpr, void> const&, tvm::Attrs const&, tvm::Array<tvm::Type, void> const&)#1}::operator()(tvm::Array<tvm::RelayExpr, void> const&, tvm::Attrs const&, tvm::Array<tvm::Type, void> const&) const+0x13a) [0x7f8cb0e7e12a]
E     [bt] (5) /workspace/build/libtvm.so(tvm::relay::vm::VMFunctionCompiler::EmitInvokeTVMOp(tvm::relay::Function const&, tvm::RelayExpr const&, tvm::RelayExpr const&)+0x8f3) [0x7f8cb0e7d633]
E     [bt] (4) /workspace/build/libtvm.so(tvm::relay::CompileEngineImpl::Lower(tvm::relay::CCacheKey const&)+0x20) [0x7f8cb0e4ef20]
E     [bt] (3) /workspace/build/libtvm.so(tvm::relay::CompileEngineImpl::LowerInternal(tvm::relay::CCacheKey const&)+0x61e) [0x7f8cb0e4e17e]
E     [bt] (2) /workspace/build/libtvm.so(tvm::relay::ScheduleGetter::Create(tvm::relay::Function const&)+0xe8f) [0x7f8cb0e4d3ef]
E     [bt] (1) /workspace/build/libtvm.so(tvm::relay::OpImplementation::Schedule(tvm::Attrs const&, tvm::Array<tvm::te::Tensor, void> const&, tvm::Target const&)+0xb1) [0x7f8cb0e8f231]
E     [bt] (0) /workspace/build/libtvm.so(+0xc5e14b) [0x7f8cb0f8714b]
E     File "tvm/_ffi/_cython/./packed_func.pxi", line 54, in tvm._ffi._cy3.core.tvm_callback
E     File "/workspace/python/tvm/relay/op/strategy/generic.py", line 738, in schedule_argwhere
E       return topi.generic.schedule_argwhere(outs)
E     File "/workspace/topi/python/topi/generic/search.py", line 35, in schedule_argwhere
E       return _default_schedule(outs, False)
E     File "/workspace/topi/python/topi/generic/vision.py", line 29, in _default_schedule
E       raise RuntimeError("schedule not registered for '%s'" % target)
E   RuntimeError: schedule not registered for 'cuda'

tvm/_ffi/_cython/./base.pxi:159: TVMError

@masahi
Member

masahi commented Mar 19, 2020

You don't need to test on CUDA; just don't use ctx_list().
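
A sketch of the adjusted check under that suggestion, pinning the test to the CPU since argwhere has no CUDA schedule yet (the 'llvm' target string and the reuse of the model built in verify_nonzero above are assumptions):

# CPU-only variant of the NonZero check, skipping ctx_list() entirely.
input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
result = np.array(np.nonzero(input_data))  # expected [[0, 1, 1], [0, 0, 1]]
tvm_out = get_tvm_output_with_vm(model, input_data, 'llvm', tvm.cpu(0), opset=9)
tvm.testing.assert_allclose(result, tvm_out, rtol=1e-05, atol=1e-05)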

@kazum kazum left a comment
Contributor

Looks good to me, thanks!

@cchung100m
Contributor Author

Hi @masahi @kazum @jwfromm

Thanks for all the reviews and suggestions. :)

@masahi masahi merged commit e1ebf06 into apache:master Mar 19, 2020
@masahi
Member

masahi commented Mar 19, 2020

Thanks @cchung100m @kazum

@cchung100m cchung100m deleted the support_operator_NonZero_in_frontend_onnx branch March 20, 2020 00:44
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* [Relay][Frontend][ONNX] operator support: NonZero

* update

* Solve the build fail

* solve the build fail

* Replace ctx_list with tvm.cpu()
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
* [Relay][Frontend][ONNX] operator support: NonZero

* update

* Solve the build fail

* solve the build fail

* Replace ctx_list with tvm.cpu()