[Relay][Frontend][ONNX] operator support NonZero #5073
Conversation
@@ -30,7 +30,8 @@
 import scipy

-def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
+def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32',
+                   opset=None, mode=None):
Just curious, why does `mode` need to be added? Does `NonZero` only work in the VM executor? If so, maybe we can just exclusively use the VM for testing; I'm not sure there's a need to support both.
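For context on why the VM executor comes up at all: NonZero is a data-dependent operator, so its output shape cannot be known at compile time, which is the kind of dynamism the VM executor handles and a statically-shaped graph executor does not. A small numpy sketch (not TVM code) of that data dependence:

```python
import numpy as np

# Two inputs with the SAME shape but different values produce outputs of
# DIFFERENT shapes -- the output shape depends on how many elements are
# nonzero, not just on the input shape.
a = np.array([0, 1, 0, 2])
b = np.array([1, 1, 1, 1])
assert np.argwhere(a).shape == (2, 1)  # two nonzero elements
assert np.argwhere(b).shape == (4, 1)  # four nonzero elements
```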
Yes, it makes sense. I will put it into a separate function instead of adding it to `get_tvm_output`.
python/tvm/relay/frontend/onnx.py
Outdated
def _impl_v9(cls, inputs, attr, params):
    if len(inputs) > 1:
        raise ValueError("Expect 1 input only")
    return AttrCvt(op_name='argwhere')(inputs, attr, params)
The output shape of `argwhere` is different from that of ONNX NonZero. A transpose is needed here.
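The layout mismatch the reviewer points out can be illustrated with numpy (a sketch, not the actual TVM code): `argwhere` follows `np.argwhere` and returns indices with shape `(num_nonzero, ndim)`, while ONNX NonZero follows `np.nonzero` semantics and returns `(ndim, num_nonzero)`, hence the transpose.

```python
import numpy as np

data = np.array([[0, 1, 0],
                 [2, 0, 3]])

# argwhere-style output: one row per nonzero element -> shape (3, 2)
argwhere_out = np.argwhere(data)

# ONNX NonZero-style output: one row per dimension -> shape (2, 3)
nonzero_out = np.stack(np.nonzero(data))

assert argwhere_out.shape == (3, 2)
assert nonzero_out.shape == (2, 3)
# The two layouts are transposes of each other.
assert (argwhere_out.T == nonzero_out).all()
```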
Thanks for the review; I have updated it.
mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)

ex = relay.create_executor(mode, mod=mod, ctx=tvm.cpu(), target=target)
`tvm.cpu()` should be `ctx`?
Thanks for the review; I have updated it.
Force-pushed from 6c76ff8 to fdd8733.
return input_names, shape_dict


def get_input_data_dtype_dict(graph_def, input_data):
This function is not used anywhere.
Hi @kazum, thanks for the review. I will remove it here.
return result.asnumpy()


def get_tvm_output(graph_def, input_data, target, ctx, output_shape=None, output_dtype='float32', opset=None):
How about using the VM for all the tests?
Hi @kazum,
I'd like to keep both `relay.create_executor` and `relay.build` in this PR. I cannot replace `relay.build` with `relay.create_executor` directly, because doing so produces many errors like the ones below:
File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2282, in <module>
test_flatten()
File "/tvm/tests/python/frontend/onnx/test_forward.py", line 374, in test_flatten
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
File "/tvm/tests/python/frontend/onnx/test_forward.py", line 70, in get_tvm_output
result = ex.evaluate()(indata)
File "/tvm/python/tvm/relay/backend/vm.py", line 256, in _vm_wrapper
return self.vm.run(*args)
File "/tvm/python/tvm/runtime/vm.py", line 366, in run
return self.invoke("main", *args, **kwargs)
File "/tvm/python/tvm/runtime/vm.py", line 348, in invoke
return self._invoke(func_name)
File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
[bt] (7) 8 ??? 0x00007fff54230930 0x0 + 140734604970288
[bt] (6) 7 _ctypes.cpython-37m-darwin.so 0x00000001104dc2bf ffi_call_unix64 + 79
[bt] (5) 6 libtvm.dylib 0x0000000125071f78 TVMFuncCall + 72
[bt] (4) 5 libtvm.dylib 0x00000001250a4e3f std::__1::__function::__func<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 735
[bt] (3) 4 libtvm.dylib 0x00000001250a199a tvm::runtime::vm::VirtualMachine::RunLoop() + 7610
[bt] (2) 3 libtvm.dylib 0x00000001250a310f tvm::runtime::vm::VirtualMachine::InvokePacked(long long, tvm::runtime::PackedFunc const&, long long, long long, std::__1::vector<tvm::runtime::ObjectRef, std::__1::allocator<tvm::runtime::ObjectRef> > const&) + 1039
[bt] (1) 2 libtvm.dylib 0x000000012507b396 std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 310
[bt] (0) 1 libtvm.dylib 0x0000000124666af9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
File "/tvm/src/runtime/library_module.cc", line 89
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (((tvm_struct_get(arg0, 0, 5) == (uint8)2) && (tvm_struct_get(arg0, 0, 6) == (uint8)32)) && (tvm_struct_get(arg0, 0, 7) == (uint16)1)), arg0.dtype is expected to be float32
File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2128, in verify_lstm
output_dtype=['float32', 'float32', 'float32'])
File "/tvm/tests/python/frontend/onnx/test_forward.py", line 69, in get_tvm_output
indata = tvm.nd.array(input_data)
File "/tvm/python/tvm/runtime/ndarray.py", line 487, in array
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)
File "/tvm/python/tvm/runtime/ndarray.py", line 270, in empty
dtype = DataType(dtype)
File "/tvm/python/tvm/_ffi/runtime_ctypes.py", line 101, in __init__
raise ValueError("Do not know how to handle type %s" % type_str)
ValueError: Do not know how to handle type object
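A guess at the failure mode behind the final `ValueError` in that trace: when `input_data` is a list of arrays with differing shapes (as in the LSTM test, which has multiple inputs and outputs), numpy packs it into an array with `dtype=object`, and there is no TVM/DLPack dtype corresponding to "object", so `tvm.nd.array(input_data)` rejects it. A minimal numpy sketch of the object-dtype packing (the TVM side is inferred from the traceback, not reproduced here):

```python
import numpy as np

# Arrays of different shapes cannot share a numeric dtype/shape, so numpy
# falls back to an object-dtype container -- the "type object" that
# tvm.nd.array() reports it cannot handle.
ragged = np.array([np.zeros((2, 3)), np.zeros(4)], dtype=object)
assert ragged.dtype == np.dtype("object")
```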
Maybe we can open another PR for the issues above and replace `relay.build` with `relay.create_executor` there?
Okay, I'm fine with keeping both for now :)
Force-pushed from fdd8733 to c7e6b10.
Hi @kazum It seems that I need to add the schedule of …

You don't need to test on cuda. Just do not use …
Looks good to me, thanks!
Thanks @cchung100m @kazum
* [Relay][Frontend][ONNX] operator support: NonZero
* update
* Solve the build fail
* solve the build fail
* Replace ctx_list with tvm.cpu()
Hi @tqchen @vinx13 @zhiics
Following issue #4568, this PR adds the NonZero operator and its test case to the ONNX frontend. I would appreciate it if you could help review/manage this PR, thanks.