[Relay][Frontend][ONNX] operator support NonZero #5073

Merged
13 changes: 13 additions & 0 deletions python/tvm/relay/frontend/onnx.py
@@ -1443,6 +1443,18 @@ def _impl_v11(cls, inputs, attr, params):
        return _op.image.resize(inputs[0], out_size, layout, method, coord_trans)


class NonZero(OnnxOpConverter):
    """Operator converter for NonZero"""

    @classmethod
    def _impl_v9(cls, inputs, attr, params):
        if len(inputs) > 1:
            raise ValueError("Expect 1 input only")

        output = AttrCvt(op_name='argwhere')(inputs, attr, params)
        return _op.transpose(output, axes=(1, 0))
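A note on the final transpose: Relay's argwhere returns indices with shape (num_nonzero, ndim), one row per nonzero element, whereas ONNX NonZero produces the transposed layout (ndim, num_nonzero), i.e. np.nonzero stacked into a single array. A minimal numpy sketch of the layout difference (illustration only, not part of this diff):

    import numpy as np

    x = np.array([[1, 0], [1, 1]])
    print(np.argwhere(x))           # shape (3, 2): [[0 0] [1 0] [1 1]]
    print(np.array(np.nonzero(x)))  # shape (2, 3): [[0 1 1] [0 0 1]]
    print(np.argwhere(x).T)         # identical to the previous line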


# compatible operators that do NOT require any conversion.
_identity_list = []

@@ -1572,6 +1584,7 @@ def _get_convert_map(opset):
        'Where': Where.get_converter(opset),
        'Or': Or.get_converter(opset),
        'Resize': Resize.get_converter(opset),
        'NonZero': NonZero.get_converter(opset),
    }
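For context on how this registration is used: get_converter selects, from the class's _impl_vN methods, the newest implementation whose version does not exceed the model's opset, so _impl_v9 above serves opset 9 and later. A hedged sketch of that mechanism (not TVM's exact code):

    class Converter:
        @classmethod
        def get_converter(cls, opset):
            # Collect the available _impl_vN versions, e.g. [1, 9, 11].
            versions = sorted(int(name.rsplit('_v', 1)[1])
                              for name in dir(cls) if name.startswith('_impl_v'))
            eligible = [v for v in versions if v <= opset]
            if not eligible:
                raise NotImplementedError("opset %d is not supported" % opset)
            return getattr(cls, '_impl_v%d' % max(eligible))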


59 changes: 53 additions & 6 deletions tests/python/frontend/onnx/test_forward.py
@@ -30,21 +30,38 @@
import scipy


def get_input_data_shape_dict(graph_def, input_data):
    if isinstance(input_data, list):
        input_names = {}
        shape_dict = {}
        for i, _ in enumerate(input_data):
            input_names[i] = graph_def.graph.input[i].name
            shape_dict[input_names[i]] = input_data[i].shape
    else:
        input_names = graph_def.graph.input[0].name
        shape_dict = {input_names: input_data.shape}

    return input_names, shape_dict
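For reference, a quick sketch of what the helper returns for a single-input model; the Identity model below is hypothetical, built only to illustrate:

    import numpy as np
    from onnx import helper, TensorProto

    x = np.zeros((2, 3), dtype='float32')
    node = helper.make_node('Identity', inputs=['X'], outputs=['Y'])
    graph = helper.make_graph([node], 'identity_test',
                              inputs=[helper.make_tensor_value_info('X', TensorProto.FLOAT, [2, 3])],
                              outputs=[helper.make_tensor_value_info('Y', TensorProto.FLOAT, [2, 3])])
    model = helper.make_model(graph, producer_name='identity_test')

    input_names, shape_dict = get_input_data_shape_dict(model, x)
    # input_names == 'X', shape_dict == {'X': (2, 3)}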


def get_tvm_output_with_vm(graph_def, input_data, target, ctx, opset=None):
    """ Generic function to execute and get tvm output with vm executor"""

    _, shape_dict = get_input_data_shape_dict(graph_def, input_data)

    mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)

    ex = relay.create_executor('vm', mod=mod, ctx=ctx, target=target)
    indata = tvm.nd.array(input_data)
    result = ex.evaluate()(indata)
    return result.asnumpy()
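Why a separate VM helper: NonZero's output shape is (ndim, num_nonzero), which depends on the input values rather than just the input shape, so it cannot go through the statically shaped relay.build path; the Relay VM executor handles such data-dependent output shapes, which is why the NonZero test below uses this helper. A small numpy illustration of the data dependence:

    import numpy as np

    a = np.array([[1, 0], [1, 1]])
    b = np.array([[0, 0], [0, 1]])
    print(np.array(np.nonzero(a)).shape)  # (2, 3) -- three nonzero elements
    print(np.array(np.nonzero(b)).shape)  # (2, 1) -- one nonzero element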


kazum (Contributor) commented on the updated get_tvm_output (shown below):
How about using vm for all the tests?

cchung100m (Contributor Author) replied:

Hi @kazum,
I'd like to keep both relay.create_executor and relay.build in this PR.

I cannot replace relay.build with relay.create_executor directly, because doing so produces many errors like the ones below:

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2282, in <module>
    test_flatten()

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 374, in test_flatten
    tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 70, in get_tvm_output
    result = ex.evaluate()(indata)

  File "/tvm/python/tvm/relay/backend/vm.py", line 256, in _vm_wrapper
    return self.vm.run(*args)

  File "/tvm/python/tvm/runtime/vm.py", line 366, in run
    return self.invoke("main", *args, **kwargs)

  File "/tvm/python/tvm/runtime/vm.py", line 348, in invoke
    return self._invoke(func_name)

  File "/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 213, in __call__
    raise get_last_ffi_error()

tvm._ffi.base.TVMError: Traceback (most recent call last):
  [bt] (7) 8   ???                                 0x00007fff54230930 0x0 + 140734604970288
  [bt] (6) 7   _ctypes.cpython-37m-darwin.so       0x00000001104dc2bf ffi_call_unix64 + 79
  [bt] (5) 6   libtvm.dylib                        0x0000000125071f78 TVMFuncCall + 72
  [bt] (4) 5   libtvm.dylib                        0x00000001250a4e3f std::__1::__function::__func<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::vm::VirtualMachine::GetFunction(std::__1::basic_string<char, std::__1::char_traits<char>, std::__1::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 735
  [bt] (3) 4   libtvm.dylib                        0x00000001250a199a tvm::runtime::vm::VirtualMachine::RunLoop() + 7610
  [bt] (2) 3   libtvm.dylib                        0x00000001250a310f tvm::runtime::vm::VirtualMachine::InvokePacked(long long, tvm::runtime::PackedFunc const&, long long, long long, std::__1::vector<tvm::runtime::ObjectRef, std::__1::allocator<tvm::runtime::ObjectRef> > const&) + 1039
  [bt] (1) 2   libtvm.dylib                        0x000000012507b396 std::__1::__function::__func<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0, std::__1::allocator<tvm::runtime::WrapPackedFunc(int (*)(TVMValue*, int*, int, TVMValue*, int*), tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0>, void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)>::operator()(tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&) + 310
  [bt] (0) 1   libtvm.dylib                        0x0000000124666af9 dmlc::LogMessageFatal::~LogMessageFatal() + 57
  File "/tvm/src/runtime/library_module.cc", line 89
TVMError: Check failed: ret == 0 (-1 vs. 0) : Assert fail: (((tvm_struct_get(arg0, 0, 5) == (uint8)2) && (tvm_struct_get(arg0, 0, 6) == (uint8)32)) && (tvm_struct_get(arg0, 0, 7) == (uint16)1)), arg0.dtype is expected to be float32

File "/tvm/tests/python/frontend/onnx/test_forward.py", line 2128, in verify_lstm
    output_dtype=['float32', 'float32', 'float32'])

  File "/tvm/tests/python/frontend/onnx/test_forward.py", line 69, in get_tvm_output
    indata = tvm.nd.array(input_data)

  File "/tvm/python/tvm/runtime/ndarray.py", line 487, in array
    return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)

  File "/tvm/python/tvm/runtime/ndarray.py", line 270, in empty
    dtype = DataType(dtype)

  File "/tvm/python/tvm/_ffi/runtime_ctypes.py", line 101, in __init__
    raise ValueError("Do not know how to handle type %s" % type_str)

ValueError: Do not know how to handle type object

Maybe we can open another PR for the above issues and replace relay.build with relay.create_executor there?

kazum (Contributor) replied:
Okay, I'm fine with keeping both for now :)

""" Generic function to execute and get tvm output"""
target = 'llvm'

input_names, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(graph_def, shape_dict, opset=opset)
with relay.build_config(opt_level=1):
@@ -2209,6 +2226,35 @@ def verify(ishape, oshape, scales, mode, coord_trans):
    verify([1, 16, 32, 32], [], [1, 1, 0.5, 0.5], "linear", "half_pixel")


def test_nonzero():

    def verify_nonzero(indata, outdata, dtype):
        node = helper.make_node('NonZero',
                                inputs=['X'],
                                outputs=['Y'])

        graph = helper.make_graph([node],
                                  "nonzero_test",
                                  inputs=[helper.make_tensor_value_info("X", TensorProto.INT64, list(indata.shape))],
                                  outputs=[helper.make_tensor_value_info("Y", TensorProto.INT64, list(outdata.shape))])

        model = helper.make_model(graph, producer_name='nonzero_test')

        onnx_out = get_onnxruntime_output(model, indata, dtype)

        for target, ctx in [('llvm', tvm.cpu())]:
            tvm_out = get_tvm_output_with_vm(model, indata, target, ctx, opset=9)
            tvm.testing.assert_allclose(onnx_out, tvm_out, rtol=1e-05, atol=1e-05)

    input_data = np.array([[1, 0], [1, 1]], dtype=np.int64)
    result = np.array(np.nonzero(input_data))  # expected output [[0, 1, 1], [0, 0, 1]]
    verify_nonzero(input_data, result, dtype=np.int64)

    input_data = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]], dtype=np.int64)
    result = np.array(np.nonzero(input_data))  # expected output [[0, 1, 2, 2], [0, 1, 0, 1]]
    verify_nonzero(input_data, result, dtype=np.int64)


if __name__ == '__main__':
    test_flatten()
    test_reshape()
@@ -2269,3 +2315,4 @@ def verify(ishape, oshape, scales, mode, coord_trans):
    test_pooling()
    test_lstm()
    test_resize()
    test_nonzero()