From b6e6cf79332b79288c9976dc7cd75cf97eb7f6dd Mon Sep 17 00:00:00 2001
From: maheshambule
Date: Thu, 12 Mar 2020 19:10:37 +0530
Subject: [PATCH 01/28] Relay to ONNX converter

---
 python/tvm/relay/converter/__init__.py |  22 +
 python/tvm/relay/converter/onnx.py     | 698 +++++++++++++++++++++++++
 2 files changed, 720 insertions(+)
 create mode 100644 python/tvm/relay/converter/__init__.py
 create mode 100644 python/tvm/relay/converter/onnx.py

diff --git a/python/tvm/relay/converter/__init__.py b/python/tvm/relay/converter/__init__.py
new file mode 100644
index 000000000000..d0246c682233
--- /dev/null
+++ b/python/tvm/relay/converter/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Converters from Relay to other IRs.
+Contains the converters for translating Relay to other IRs.
+"""
+
+from .onnx import to_onnx

diff --git a/python/tvm/relay/converter/onnx.py b/python/tvm/relay/converter/onnx.py
new file mode 100644
index 000000000000..05ce6e701dc6
--- /dev/null
+++ b/python/tvm/relay/converter/onnx.py
@@ -0,0 +1,698 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
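+#
+# Usage sketch (illustrative only; the public entry point is the to_onnx()
+# function defined at the bottom of this module):
+#
+#   from tvm.relay.converter import to_onnx
+#   # mod: a tvm.relay.Module; params: dict of parameter name -> NDArray
+#   onnx_model = to_onnx(mod, params, name="my_model", opset_version=11,
+#                        path="my_model.onnx")  # path=None skips saving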
+# pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines, redefined-builtin +"""Relay to ONNX serialization """ + +import numpy +import onnx +import onnx.utils +from onnx import numpy_helper, OperatorSetIdProto, defs +import tvm +from tvm.autotvm.graph_tuner.utils.traverse_graph import _expr2graph_impl +from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple + +ONNX_OPSET_VERSONS_SUPPORTED = [11] + + +def tvm_array_to_list(arr): + return tuple(x.value for x in arr) + + +def get_onnx_version(): + return onnx.__version__ + + +def add_input(data, name, model_container): + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[data.dtype] + tensor_value_info = onnx.helper.make_tensor_value_info(name, dtype, shape=data.shape) + model_container.add_inputs([tensor_value_info]) + data_tensor = numpy_helper.from_array(data, name) + model_container.add_initializers([data_tensor]) + + +class OpConverter(object): + """ Operator converter Base Class. + """ + + @classmethod + def convert_attributes(cls, attrs): + """convert Relay attributes to ONNX attributes. + The derived classes should implement this method + if attributes are required by the operator + otherwise by default no attributes are passed + """ + return {} + + @classmethod + def convert(cls, node, model_container, node_list): + attrs = cls.convert_attributes(node['node'].attrs) + node = onnx.helper.make_node(cls.__name__, + node['input_names'], + node['output_names'], + **attrs) + model_container.add_nodes([node]) + + +def rename(op_name): + """ This method creates dynamic operator of name op_name with empty attributes + """ + return type(op_name, (OpConverter,), {}) + + +class Reshape(object): + """ Operator converter for Reshape. + """ + + @classmethod + def convert(cls, node, model_container, node_list): + """Converts Relay operator Reshape to ONNX operator. + Relay operator accepts shape as attribute but ONNX operator + accepts it as a input. + """ + + shape = numpy.asarray([a.value for a in node['node'].attrs.newshape], + dtype=numpy.int64) + input_name = 'shape{}'.format(node['output_names'][0]) + node = onnx.helper.make_node(cls.__name__, [node['input_names'][0], input_name], + node['output_names']) + model_container.add_nodes([node]) + add_input(shape, input_name, model_container) + + +class Conv(OpConverter): + """ Operator converter for Conv. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'group': attrs.get_int("groups"), + 'pads': attrs.get_int_tuple("padding"), + 'strides': attrs.get_int_tuple("strides"), + 'dilations': attrs.get_int_tuple("dilation"), + 'kernel_shape': attrs.get_int_tuple("kernel_size"), + } + + +class MaxPool(OpConverter): + """ Operator converter for MaxPool. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'pads': attrs.get_int_tuple("padding") + attrs.get_int_tuple("padding"), + 'strides': attrs.get_int_tuple("strides"), + 'kernel_shape': attrs.get_int_tuple("pool_size"), + } + + +class Transpose(OpConverter): + """ Operator converter for Transpose. + """ + + @classmethod + def convert_attributes(cls, attrs): + return {'perm': attrs.get_int_tuple("axes")} if attrs["axes"] else {} + + +class MatMul(OpConverter): + """ Operator converter for MatMul. 
+ """ + + @classmethod + def convert(cls, node, model_container, node_list): + output_name = 'inter{}'.format(node['output_names'][0]) + transpose_node = onnx.helper.make_node(Transpose.__name__, + [node['input_names'][1]], + [output_name], + **{'perm': (1, 0)}) + model_container.add_nodes([transpose_node]) + + inputs = [node['input_names'][0], output_name] + matmul_node = onnx.helper.make_node(cls.__name__, inputs, node['output_names']) + model_container.add_nodes([matmul_node]) + + +class Flatten(OpConverter): + """ Operator converter for Flatten. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axis': 1, + } + + +class BatchNormalization(OpConverter): + """ Operator converter for BatchNormalization. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'epsilon': float(attrs.get_str('epsilon')), + 'axis': float(attrs.get_int('axis')), + } + + @classmethod + def convert(cls, node, model_container, node_list): + """Converts Relay operator batch_norm to ONNX operator. + Relay operator has property axis to handle data in NHWC format. + """ + attrs = cls.convert_attributes(node['node'].attrs) + transpose_out_name = node['input_names'][0] + output_names = node['output_names'] + + # axis==3 means channel is specified along the 3rd axis + if attrs['axis'] == 3: + transpose_out_name = 'transpose_{}'.format(node['output_names'][0]) + node_transposed = onnx.helper.make_node(Transpose.__name__, + [node['input_names'][0]], + [transpose_out_name], + **{'perm': [0, 3, 1, 2]}) + model_container.add_nodes([node_transposed]) + output_names = ['batch_norm_{}'.format(node['output_names'][0])] + + batch_norm_node = onnx.helper.make_node(cls.__name__, + [transpose_out_name] + node['input_names'][1:], + output_names, + **{'epsilon': attrs['epsilon']}) + model_container.add_nodes([batch_norm_node]) + + if attrs['axis'] == 3: + node_transposed = onnx.helper.make_node(Transpose.__name__, + output_names, + node['output_names'], + **{'perm': [0, 2, 3, 1]}) + model_container.add_nodes([node_transposed]) + + +class Dropout(OpConverter): + """ Operator converter for Dropout. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'ratio': float(attrs.get_str('rate')), + } + + +class AveragePool(MaxPool): + """ Operator converter for AveragePool. + """ + + +class Concat(OpConverter): + """ Operator converter for Concat. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axis': attrs.get_int("axis"), + } + + +class BiasAdd(OpConverter): + """ Operator converter for BiasAdd. + """ + + @classmethod + def convert(cls, node, model_container, node_list): + input_node = node_list[node['inputs'][0][0]] + data_ndim = len(input_node['types'][0].shape) + axis = node['node'].attrs.get_int("axis") + if axis < 0: + axis = axis + data_ndim + new_axes = data_ndim - axis - 1 + if new_axes: + output_name = 'inter{}'.format(node['output_names'][0]) + unsqueeze_node = onnx.helper.make_node('Unsqueeze', + [node['input_names'][1]], + [output_name], + **{'axes': tuple(range(1, new_axes + 1))}) + model_container.add_nodes([unsqueeze_node]) + else: + output_name = node['input_names'][1] + + inputs = [node['input_names'][0], output_name] + matmul_node = onnx.helper.make_node('Add', inputs, node['output_names']) + model_container.add_nodes([matmul_node]) + + +class ReduceMean(OpConverter): + """ Operator converter for ReduceMean. 
+ """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axes': attrs.axis, + 'keepdims': 0 if bool(attrs.get_int("keepdims", 0)) is False else 1 + } + + @classmethod + def convert(cls, node, model_container, node_list): + input_node = node_list[node['inputs'][0][0]] + shape = input_node['types'][0].shape + axis = node['node'].attrs.axis + axis = list(range(shape.size())) if not axis else tvm_array_to_list(axis) + exclude = 0 if not bool(node['node'].attrs.exclude) else 1 + keepdims = 0 if not bool(node['node'].attrs.keepdims) else 1 + if exclude: + all_axis = list(range(len(shape))) + axis = set(all_axis) - set(axis) + + node = onnx.helper.make_node(cls.__name__, + node['input_names'], + node['output_names'], + **{"axes": axis, + "keepdims": keepdims}) + model_container.add_nodes([node]) + + +class Pad(OpConverter): + """ Operator converter for Pad. + """ + + @classmethod + def convert_attributes(cls, attrs): + before = [] + after = [] + for axis_pads in attrs.pad_width: + before.append(axis_pads[0]) + after.append(axis_pads[1]) + pads = before + after + pads = numpy.asarray(pads, dtype=pads[0].dtype) + return { + 'pads': pads, + 'mode': attrs.get_str('pad_mode'), + 'constant_value': attrs.pad_value + } + + @classmethod + def convert(cls, node, model_container, node_list): + """Converts Relay operator Pad to ONNX operator. + Relay operator accepts pads as attribute but ONNX operator + accepts it as a input. + """ + attrs = cls.convert_attributes(node['node'].attrs) + + data = numpy.asarray(attrs['pads'], dtype=attrs['pads'][0].dtype).astype(numpy.int64) + input_name = 'pads_{}'.format(node['output_names'][0]) + value = numpy.dtype(node['types'][0].dtype).type(attrs['constant_value']) + input_value_name = 'value_{}'.format(node['output_names'][0]) + add_input(data, input_name, model_container) + add_input(value, input_value_name, model_container) + + input_names = [node['input_names'][0], input_name, input_value_name] + node = onnx.helper.make_node(cls.__name__, input_names, node['output_names']) + model_container.add_nodes([node]) + + +class Softmax(OpConverter): + """ Operator converter for SoftMax. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axis': attrs.axis, + } + + +class Squeeze(OpConverter): + """ Operator converter for Squeeze. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'axes': attrs.axis, + } + + @classmethod + def convert(cls, node, model_container, node_list): + input_node = node_list[node['inputs'][0][0]] + shape = input_node['types'][0].shape + axis = node['node'].attrs.get_int("axis") + if not axis: + axis = [] + for axis_idx, val in enumerate(shape): + if val.value == 1: + axis.append(axis_idx) + else: + axis = node['node'].attrs.get_int_tuple("axis") + + node = onnx.helper.make_node(cls.__name__, + node['input_names'], + node['output_names'], + **{"axes": axis}) + model_container.add_nodes([node]) + + +class Slice(OpConverter): + """ Operator converter for Slice. 
+ """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'starts': attrs.get_int_tuple('begin'), + 'ends': attrs.get_int_tuple('end'), + 'steps': attrs.get_int_tuple('strides') + } + + @classmethod + def convert(cls, node, model_container, node_list): + attrs = cls.convert_attributes(node['node'].attrs) + + input_node = node_list[node['inputs'][0][0]] + shape = input_node['types'][0].shape + starts = list(attrs['starts']) + ends = list(attrs['ends']) + for i in range(len(starts), len(shape)): + starts.append(0) + for i in range(len(ends), len(shape)): + ends.append(shape[i] + 1) + + starts = numpy.asarray(starts).astype(numpy.int64) + starts_name = 'starts_{}'.format(node['output_names'][0]) + add_input(starts, starts_name, model_container) + + ends = numpy.asarray(ends).astype(numpy.int64) + ends_name = 'ends_{}'.format(node['output_names'][0]) + add_input(ends, ends_name, model_container) + + input_names = node['input_names'] + [starts_name, ends_name] + + if attrs['steps']: + axes = list(range(len(shape))) + attrs['axes'] = axes + assert len(axes) == len(attrs['steps']), "axes and steps should be of same size" + + steps = numpy.asarray(attrs['steps']).astype(numpy.int64) + steps_name = 'steps_{}'.format(node['output_names'][0]) + add_input(steps, steps_name, model_container) + + axes = numpy.asarray(attrs['axes']).astype(numpy.int64) + axes_name = 'axes_{}'.format(node['output_names'][0]) + add_input(axes, axes_name, model_container) + + input_names = input_names + [axes_name, steps_name] + + slice_node = onnx.helper.make_node(cls.__name__, + input_names, + node['output_names']) + model_container.add_nodes([slice_node]) + + +class ConstantOfShapeZeros(OpConverter): + """ Operator converter for ConstantOfShape. + """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'value': 0 + } + + @classmethod + def convert(cls, node, model_container, node_list): + attrs = cls.convert_attributes(node['node'].attrs) + input_node = node_list[node['inputs'][0][0]] + shape = input_node['types'][0].shape + dtype = input_node['types'][0].dtype + input_shape_name = 'shape_{}'.format(node['output_names'][0]) + shape = numpy.asarray(shape).astype(numpy.int64) + add_input(shape, input_shape_name, model_container) + + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(dtype)] + tensor_value = onnx.helper.make_tensor("value", dtype, + [1], [attrs['value']]) + + node = onnx.helper.make_node('ConstantOfShape', + [input_shape_name], + node['output_names'], + **{'value': tensor_value}) + model_container.add_nodes([node]) + + +class ConstantOfShapeOnes(ConstantOfShapeZeros): + """ Operator converter for ConstantOfShape. 
+ """ + + @classmethod + def convert_attributes(cls, attrs): + return { + 'value': 1 + } + + +relay_to_onnx_op_mapping = { + 'reshape': Reshape, + 'nn.conv2d': Conv, + 'add': rename('Add'), + 'nn.relu': rename('Relu'), + 'transpose': Transpose, + 'nn.dense': MatMul, + 'nn.max_pool2d': MaxPool, + 'nn.batch_flatten': Flatten, + 'multiply': rename('Mul'), + 'nn.bias_add': BiasAdd, + 'nn.batch_norm': BatchNormalization, + 'nn.global_avg_pool2d': rename('GlobalAveragePool'), + 'concatenate': Concat, + 'nn.dropout': Dropout, + 'nn.avg_pool2d': AveragePool, + 'divide': rename('Div'), + 'mean': ReduceMean, + 'nn.pad': Pad, + 'nn.softmax': Softmax, + 'squeeze': Squeeze, + 'strided_slice': Slice, + 'greater': rename('Greater'), + 'less': rename('Less'), + 'equal': rename('Equal'), + 'zeros_like': ConstantOfShapeZeros, + 'ones_like': ConstantOfShapeOnes, + 'subtract': rename('Sub') +} + + +class ModelContainer(object): + """ A container class to hold different attributes of ONNX model graph + """ + + def __init__(self, name, opset_version): + self._name = name + self._opset_version = opset_version + self._inputs = [] + self._outputs = [] + self._nodes = [] + self._initializers = [] + + def add_inputs(self, inputs): + self._inputs.extend(inputs) + + def add_outputs(self, outputs): + self._outputs.extend(outputs) + + def add_nodes(self, nodes): + self._nodes.extend(nodes) + + def add_initializers(self, initializers): + self._initializers.extend(initializers) + + def _get_opsets(self): + opsets = [] + imp = OperatorSetIdProto() + imp.version = self._opset_version + opsets.append(imp) + return opsets + + def make_model(self): + """ Creates the onnx model from the graph """ + onnx_graph = onnx.helper.make_graph( + self._nodes, + self._name, + self._inputs, + self._outputs, + self._initializers + ) + kwargs = {} + kwargs["opset_imports"] = self._get_opsets() + kwargs["producer_name"] = 'TVM Relay' + kwargs["producer_name"] = tvm.__version__ + + return onnx.helper.make_model(onnx_graph, **kwargs) + + +class RelayToONNXConverter(object): + """A helper class converting topologically sorted Relay nodes to ONNX model + + Parameters + ---------- + name : str + name of the model + + node_list : list + topologically sorted Relay Node entry list + """ + + def __init__(self, name, node_list, params, opset_version): + self._name = {} + self._mc = ModelContainer(name, opset_version) + self._node_list = node_list + self._params = params + + def convert_to_onnx(self): + """ Loop through topologically sorted list of Relay nodes and generate a ONNX model""" + for idx, node_entry in enumerate(self._node_list): + out_idx = idx + node = node_entry['node'] + if isinstance(node, Call): + self._add_node(node_entry, idx) + elif isinstance(node, Var): + self._add_input(node_entry, idx) + elif isinstance(node, Constant): + self._add_constant_input(node_entry, idx) + elif isinstance(node, (TupleGetItem, Tuple)): + out_idx = idx - 1 # TODO: Need to work on this. 
+ # No equivalent ONNX operator found yet + else: + raise NotImplementedError("Relay Node of type {0} is not " + "implemented yet".format(type(node))) + + if idx == len(self._node_list) - 1: + self._add_output(self._node_list[out_idx], out_idx) + + model = self._mc.make_model() + polished_model = onnx.utils.polish_model(model) + return polished_model + + def _tuple_to_name(self, input): + """convert tuple of node indexes to string""" + return 'node_{0}'.format(input[0]) + + def _add_node(self, node_entry, idx): + """Convert Relay operator node to ONNX operator and add it to container nodes list""" + if node_entry['op'].name not in relay_to_onnx_op_mapping: + raise NotImplementedError("Currently the operator '{0}' is " + "not supported.".format(node_entry['op'].name)) + + converter = relay_to_onnx_op_mapping[node_entry['op'].name]() + node_entry['output_names'] = [self._tuple_to_name([idx, 0, 0])] + node_entry['input_names'] = [] + for input_idx_tuple in node_entry['inputs']: + if self._node_list[input_idx_tuple[0]]['name']: + node_entry['input_names'].append(self._node_list[input_idx_tuple[0]]['name']) + else: + node_entry['input_names'].append(self._tuple_to_name(input_idx_tuple)) + + converter.convert(node_entry, self._mc, self._node_list) + + def _add_params(self, node_entry, idx): + """Add param value to initializer and name to inputs""" + param_name = node_entry['name'] + assert param_name in self._params, "The parameter {0} is not present" \ + "in params dict provided.".format(param_name) + value = self._params[param_name] + numpy_array = value.asnumpy() + tensor = numpy_helper.from_array(numpy_array, param_name) + self._mc.add_initializers([tensor]) + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy_array.dtype] + input = onnx.helper.make_tensor_value_info(param_name, + dtype, + shape=numpy_array.shape) + self._mc.add_inputs([input]) + + def _add_constant_input(self, node_entry, idx): + """Create named input for constant and add it to container inputs. + If input is a parameter then add to param + """ + node = node_entry['node'] + if not node_entry['name']: + node_entry['name'] = self._tuple_to_name([idx, 0, 0]) + param_name = node_entry['name'] + self._params[param_name] = node.data + self._add_params(node_entry, idx) + + def _add_input(self, node_entry, idx): + """Add input node to container inputs. 
If input is a parameter then add to param""" + if node_entry['name'] in self._params: + self._add_params(node_entry, idx) + else: + type = node_entry['types'][0] + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(type.dtype)] + input = onnx.helper.make_tensor_value_info(node_entry['name'], + dtype, + shape=type.concrete_shape) + self._mc.add_inputs([input]) + + def _add_output(self, node_entry, idx): + """Add output node to container outputs.""" + + type = node_entry['types'][0] + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(type.dtype)] + output = onnx.helper.make_tensor_value_info(self._tuple_to_name([idx, 0, 0]), + dtype, + shape=type.concrete_shape) + self._mc.add_outputs([output]) + + +def to_onnx(relay_module, params, name, opset_version=11, path=None): + """Convert a Relay Function Module into an equivalent ONNX and serialize it to the path + + Parameters + ---------- + relay_module : tvm.relay.Module + The relay module object + + params : dict + dict of the parameter names and NDarray values + + path : str + The path where ONNX model will be saved + + Returns + ------- + inferred_model : tvm.relay.Module + The relay module + + """ + + if opset_version not in ONNX_OPSET_VERSONS_SUPPORTED: + raise NotImplementedError("Currently only opset version 11 is supported.") + + if opset_version > defs.onnx_opset_version(): + raise Exception("The ONNX package installed of version {} does not support the opset " + "version {}. Upgrade the ONNX package to latest version.".format( + get_onnx_version(), opset_version)) + + node_list = [] # ONNX needs a topologically sorted list of nodes + node_dict = {} + _expr2graph_impl(relay_module["main"], [], node_dict, node_list) + converter = RelayToONNXConverter(name, node_list, params, opset_version) + onnx_model = converter.convert_to_onnx() + + if path: + onnx.save(onnx_model, path) + return onnx_model From 8319b76b5d8e86d63742caa6c7dd38f6569603d6 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 12 Mar 2020 19:11:16 +0530 Subject: [PATCH 02/28] Relay to ONNX op test cases --- tests/python/relay/converter/test_onnx.py | 403 ++++++++++++++++++++++ 1 file changed, 403 insertions(+) create mode 100644 tests/python/relay/converter/test_onnx.py diff --git a/tests/python/relay/converter/test_onnx.py b/tests/python/relay/converter/test_onnx.py new file mode 100644 index 000000000000..f52a9c12644c --- /dev/null +++ b/tests/python/relay/converter/test_onnx.py @@ -0,0 +1,403 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
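+#
+# Every test below follows the same round trip, built from the helpers in
+# this file: construct a small Relay function, run it on the graph executor,
+# convert it with to_onnx(), execute the result with onnxruntime, and compare
+# the outputs with np.testing.assert_allclose. A minimal sketch (the operator
+# and test name here are hypothetical):
+#
+#   x = relay.var("x", shape=(1, 4), dtype="float32")
+#   func = relay.Function([x], relay.nn.relu(x))
+#   x_data = np.random.rand(1, 4).astype("float32")
+#   verify_results(func, [x_data], 'test_relu', rtol=1e-5, atol=1e-5)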
+ +"""Relay to ONNX serialization test cases""" +import numpy as np +import tvm +from tvm import relay +from tvm.relay.converter import to_onnx +import onnxruntime as rt + + +def func_to_onnx(func, name): + mod = tvm.IRModule() + mod['main'] = func + onnx_model = to_onnx(mod, {}, name, path=None) + return onnx_model.SerializeToString() + + +def run_onnx(onnx_model, input_data): + sess = rt.InferenceSession(onnx_model) + input_names = {} + for input, data in zip(sess.get_inputs(), input_data): + input_names[input.name] = data + output_name = sess.get_outputs()[0].name + res = sess.run([output_name], input_names) + return res[0] + + +def run_relay(func, data_tuple): + target = 'llvm' + ctx = tvm.context('llvm', 0) + intrp = relay.create_executor("graph", ctx=ctx, target=target) + relay_res = intrp.evaluate(func)(*data_tuple) + return relay_res.asnumpy() + + +def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0): + relay_res = run_relay(relay_func, indata) + onnx_res = run_onnx(func_to_onnx(relay_func, test_name), indata) + np.testing.assert_allclose(relay_res, onnx_res, rtol=rtol, atol=atol) + + +def test_add(): + dtype = 'float32' + t1 = relay.TensorType((5, 10, 5)) + t2 = relay.TensorType((5, 10, 5)) + x = relay.var("x", t1, dtype=dtype) + y = relay.var("y", t2, dtype=dtype) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + x_data = np.random.rand(5, 10, 5).astype(dtype) + y_data = np.random.rand(5, 10, 5).astype(dtype) + + verify_results(func, [x_data, y_data], 'test_add') + + +def test_bias_add(): + for dtype in ['float16', 'float32']: + xshape = (10, 2, 3, 4) + bshape = (2,) + rtol = 1e-2 if dtype is 'float16' else 1e-5 + x = relay.var("x", shape=xshape, dtype=dtype) + bias = relay.var("bias", dtype=dtype) + z = relay.nn.bias_add(x, bias) + func = relay.Function([x, bias], z) + + x_data = np.random.uniform(size=xshape).astype(dtype) + y_data = np.random.uniform(size=bshape).astype(dtype) + + verify_results(func, [x_data, y_data], 'test_bias_add', rtol=rtol) + + +def test_conv2d(): + def verify_conv2d(dtype, scale, dshape, kshape, + padding=(1, 1), + groups=1, + dilation=(1, 1), + **attrs): + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", shape=kshape, dtype=dtype) + y = relay.nn.conv2d(x, w, + padding=padding, + dilation=dilation, + groups=groups, + **attrs) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + verify_results(func, [data, kernel], 'test_conv2d', rtol=1e-5, atol=1e-5) + + dshape = (1, 32, 18, 18) + kshape = (32, 1, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=32, groups=32, kernel_size=(3, 3)) + + dshape = (1, 32, 18, 18) + kshape = (32, 4, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=32, groups=8, kernel_size=(3, 3)) + + # also group conv2d + dshape = (1, 32, 18, 18) + kshape = (64, 1, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=64, groups=32, kernel_size=(3, 3)) + + # normal conv2d + dshape = (1, 3, 224, 224) + kshape = (10, 3, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(3, 3)) + + dshape = (1, 3, 224, 224) + kshape = (10, 3, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(2, 2), channels=10, kernel_size=(3, 3)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 3, 3) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), 
channels=10, kernel_size=(3, 3), dilation=(3, 3)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 2, 2) + verify_conv2d("float32", 1, dshape, kshape, + padding=(2, 2), channels=10, kernel_size=(2, 2), dilation=(1, 1)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 4, 4) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(4, 4)) + + dshape = (1, 3, 18, 18) + kshape = (10, 3, 4, 4) + verify_conv2d("float32", 1, dshape, kshape, + padding=(1, 1), channels=10, kernel_size=(4, 4)) + + +def test_reshape(): + def verify_reshape(shape, newshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.reshape(x, newshape=newshape) + + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + verify_results(func, [x_data], 'test_reshape', rtol=1e-5, atol=1e-5) + + verify_reshape((2, 3, 4), tuple(np.array([4, 2, 3], dtype=np.int64))) + verify_reshape((2, 3, 4), tuple(np.array([2, 0, 0], dtype=np.int64))) + verify_reshape((2, 3, 4), tuple(np.array([0, -1], dtype=np.int64))) + verify_reshape((2, 3, 4), tuple(np.array([-1, 0], dtype=np.int64))) + + +def test_transpose(): + def verify_reshape(shape, newshape): + x = relay.var("x", relay.TensorType(shape, "float32")) + z = relay.transpose(x, newshape) + func = relay.Function([x], z) + x_data = np.random.uniform(low=-1, high=1, size=shape).astype("float32") + verify_results(func, [x_data], 'test_transpose', rtol=1e-5, atol=1e-5) + + verify_reshape((1, 2, 3, 4), (0, 2, 3, 1)) + verify_reshape((1, 2, 3, 4), (0, 3, 2, 1)) + + +def test_dense(): + def verify_dense(d_shape, w_shape): + data = relay.var("data", relay.TensorType(d_shape, "float32")) + weight = relay.var("weight", relay.TensorType(w_shape, "float32")) + func = relay.Function([data, weight], relay.nn.dense(data, weight)) + x_data = np.random.uniform(size=d_shape).astype("float32") + w_data = np.random.uniform(size=w_shape).astype("float32") + verify_results(func, [x_data, w_data], 'test_dense', rtol=1e-5, atol=1e-5) + + verify_dense((1, 8), (16, 8)) + verify_dense((1, 4), (3, 4)) + + +def test_max_pool(): + def verify_max_pool(x_shape, pool_size, strides, padding, ceil_mode): + x = relay.var("x", relay.TensorType(x_shape, "float32")) + y = tvm.relay.nn.max_pool2d(x, pool_size=pool_size, strides=strides, padding=padding, + ceil_mode=ceil_mode) + func = relay.Function([x], y) + x_data = np.random.uniform(size=x_shape).astype("float32") + verify_results(func, [x_data], 'test_max_pool', rtol=1e-5, atol=1e-5) + + verify_max_pool((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False) + + +def test_batch_flatten(): + def verify_test_batch_flatten(d_shape): + data = relay.var("data", relay.TensorType(d_shape, "float32")) + func = relay.Function([data], relay.nn.batch_flatten(data)) + x_data = np.random.uniform(size=d_shape).astype("float32") + verify_results(func, [x_data], 'test_batch_flatten', rtol=1e-5, atol=1e-5) + + verify_test_batch_flatten((1, 2, 3, 4)) + verify_test_batch_flatten((1, 8)) + + +def test_bias_add(): + def verify_bias_add(): + data = relay.var("data", relay.TensorType((1, 16), "float32")) + bias = relay.var("bias", relay.TensorType((16,), "float32")) + func = relay.Function([data, bias], relay.nn.bias_add(data, bias)) + + x_data = np.random.uniform(size=(1, 16)).astype("float32") + bias = np.random.uniform(size=(16,)).astype("float32") + verify_results(func, [x_data, bias], 'test_bias_add', rtol=1e-5, atol=1e-5) + + verify_bias_add() + + +def test_batch_norm(): 
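+    # axis=3 exercises the NHWC layout path: ONNX BatchNormalization only
+    # accepts the channel on axis 1, so the converter wraps the op in a pair
+    # of Transpose nodes (see BatchNormalization.convert in the converter).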
+ def verify_batch_norm(axis=1): + for dtype in ['float16', 'float32']: + data = relay.var("data", relay.TensorType((2, 4, 4, 1), dtype)) + gamma_shape = (data.type_annotation.shape[axis].value,) + beta = relay.var("beta", relay.TensorType(gamma_shape, dtype)) + gamma = relay.var("gamma", relay.TensorType(gamma_shape, dtype)) + moving_mean = relay.var("moving_mean", relay.TensorType(gamma_shape, dtype)) + moving_var = relay.var("moving_var", relay.TensorType(gamma_shape, dtype)) + y = relay.nn.batch_norm(data, gamma, beta, moving_mean, moving_var, axis=axis) + func = relay.Function([data, gamma, beta, moving_mean, moving_var], y[0]) + + x_data = np.random.uniform(size=(2, 4, 4, 1)).astype(dtype) + beta = np.random.uniform(size=gamma_shape).astype(dtype) + gamma = np.random.uniform(size=gamma_shape).astype(dtype) + moving_mean = np.random.uniform(size=gamma_shape).astype(dtype) + moving_var = np.random.uniform(size=gamma_shape).astype(dtype) + verify_results(func, [x_data, gamma, beta, moving_mean, moving_var], 'test_batch_norm', rtol=1e-3, + atol=1e-3) + + verify_batch_norm(axis=1) + verify_batch_norm(axis=3) + + +def test_pad(): + def verify_pad(): + for dtype in ['float16', 'float32']: + dshape = (4, 10, 7, 7) + x = relay.var("x", shape=dshape, dtype=dtype) + y = relay.nn.pad(x, ((1, 1), (2, 2), (3, 3), (4, 4))) + func = relay.Function([x], y) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], 'test_pad', rtol=1e-5, atol=1e-5) + + verify_pad() + + +def test_sofmax(): + def verify_sofmax(): + for dtype in ['float32']: + shape = (10, 4) + x = relay.var("x", shape=shape, dtype=dtype) + y = relay.nn.softmax(x, axis=1) + func = relay.Function([x], y) + x_data = np.random.uniform(size=shape).astype(dtype) + verify_results(func, [x_data], 'test_softmax', rtol=1e-5, atol=1e-5) + + verify_sofmax() + + +def test_squeeze(): + def verify_squeeze(shape, dtype, axis): + x = relay.var("x", relay.TensorType(shape, dtype)) + z = relay.squeeze(x, axis=axis) + func = relay.Function([x], z) + x_data = np.random.random_sample(shape).astype(dtype) + verify_results(func, [x_data], 'test_squeeze', rtol=1e-5, atol=1e-5) + + verify_squeeze((1, 3, 2, 5), "float32", None) + verify_squeeze((1, 3, 1), "float32", [2, ]) + verify_squeeze((1, 2, 1, 2, 1), "float32", [0, 2]) + + +def test_mean(): + def verify_mean(data_shape, axis, exclude, keepdims): + dtype = "float32" + x = relay.var('x', shape=data_shape, dtype=dtype) + y = relay.mean(x, axis, keepdims, exclude) + func = relay.Function([x], y) + x_data = np.random.uniform(size=data_shape).astype(dtype) + verify_results(func, [x_data], 'test_mean', rtol=1e-5, atol=1e-5) + + verify_mean((1, 2), 0, False, False) + verify_mean((1, 2), 0, True, False) + verify_mean((1, 2), 0, True, True) + verify_mean((1, 2), 1, True, True) + verify_mean((3, 2, 1), 1, False, True) + + +def test_strided_slice(): + def verify_strided_slice(dshape, begin, end, strides): + x = relay.var("x", relay.TensorType(dshape, "float32")) + z = relay.strided_slice(x, begin=begin, end=end, strides=strides) + func = relay.Function([x], z) + x_data = np.random.uniform(size=dshape).astype("float32") + verify_results(func, [x_data], 'test_strided_slice', rtol=1e-5, atol=1e-5) + + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], None) + verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]) + verify_strided_slice((3, 4, 
3), [1, 0, 0], [2, 2, 3], [1, 1, 2]) + verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) + + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None) + + # TODO - test cases below fails for TVM itself error -strided_slice get empty slice at axis 1 + # verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], [1, -1, 1]) + # verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2]) + # verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], [1, -1, 1]) + # verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2]) + + +def test_cmp_type(): + for op, ref in ((relay.greater, np.greater), + (relay.less, np.less), + (relay.equal, np.equal) + ): + x_shape = (10, 4) + y_shape = (5, 10, 1) + t1 = relay.TensorType(x_shape) + t2 = relay.TensorType(y_shape) + x = relay.var("x", t1) + y = relay.var("y", t2) + z = op(x, y) + x_data = np.random.rand(*x_shape).astype(t1.dtype) + y_data = np.random.rand(*y_shape).astype(t2.dtype) + func = relay.Function([x, y], z) + verify_results(func, [x_data, y_data], 'test_cmp_type', rtol=1e-5, atol=1e-5) + + +def test_unary_identity(): + for dtype in ["int16", "float32", "float64"]: + for op, ref in [(relay.zeros_like, np.zeros_like), + (relay.ones_like, np.ones_like)]: + shape = (8, 9, 4) + x = relay.var("x", relay.TensorType(shape, dtype)) + y = op(x) + func = relay.Function([x, ], y) + x_data = np.random.rand(*shape).astype(dtype) + verify_results(func, [x_data], 'test_cmp_type', rtol=1e-5, atol=1e-5) + + +def test_binary_op(): + def check_binary_op(opfunc, dtype): + t1 = relay.TensorType((5, 10, 5)) + t2 = relay.TensorType((5, 10, 5)) + x = relay.var("x", t1, dtype=dtype) + y = relay.var("y", t2, dtype=dtype) + z = opfunc(x, y) + x_data = np.random.rand(5, 10, 5).astype(dtype) + y_data = np.random.rand(5, 10, 5).astype(dtype) + func = relay.Function([x, y], z) + verify_results(func, [x_data, y_data], 'test_binary_op', rtol=1e-5, atol=1e-5) + + for opfunc, ref in [(relay.add, np.add), + (relay.subtract, np.subtract), + (relay.multiply, np.multiply), + (relay.divide, np.divide), + ]: + for dtype in ['float32']: + check_binary_op(opfunc, dtype) + + +if __name__ == '__main__': + test_add() + test_bias_add() + test_conv2d() + test_reshape() + test_transpose() + test_dense() + test_max_pool() + test_batch_flatten() + test_bias_add() + test_batch_norm() + test_pad() + test_mean() + test_sofmax() + test_squeeze() + test_strided_slice() + test_cmp_type() + test_binary_op() From 72f6b2b82de323174f1d34ce5cbf65f87c36a523 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 12 Mar 2020 19:11:34 +0530 Subject: [PATCH 03/28] Relay to ONNX end to end model test cases --- tests/python/relay/converter/test_model.py | 85 ++++++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/python/relay/converter/test_model.py diff --git a/tests/python/relay/converter/test_model.py b/tests/python/relay/converter/test_model.py new file mode 100644 index 000000000000..d8e982e58cca --- /dev/null +++ b/tests/python/relay/converter/test_model.py @@ -0,0 +1,85 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Relay to ONNX serialization test cases""" +from collections import OrderedDict +import numpy as np +import tvm +from tvm import relay +from tvm.relay.converter import to_onnx +import onnxruntime as rt +import tvm.relay.testing + + +def func_to_onnx(mod, params, name): + onnx_model = to_onnx(mod, params, name, path=None) + return onnx_model.SerializeToString() + + +def run_onnx(mod, params, name, input_data): + onnx_model = func_to_onnx(mod, params, name) + sess = rt.InferenceSession(onnx_model) + input_names = {} + for input, data in zip(sess.get_inputs(), input_data): + input_names[input.name] = data + output_names = [output.name for output in sess.get_outputs()] + res = sess.run(output_names, input_names) + return res[0] + + +def get_data(in_data_shapes, dtype='float32'): + in_data = OrderedDict() + for name, shape in in_data_shapes.items(): + in_data[name] = np.random.uniform(size=shape).astype(dtype) + return in_data + + +def run_relay(mod, params, in_data): + target = 'llvm' + ctx = tvm.context('llvm', 0) + intrp = relay.create_executor("graph", mod, ctx=ctx, target=target) + in_data = [tvm.nd.array(value) for value in in_data.values()] + return intrp.evaluate()(*in_data, **params).asnumpy() + + +def _verify_results(mod, params, in_data): + a = run_relay(mod, params, in_data) + b = run_onnx(mod, params, 'test_resent', in_data.values()) + np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7) + + +def test_resnet(): + num_class = 1000 + in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) + in_data = get_data(in_data_shapes, dtype="float32") + for n in [18, 34, 50, 101]: + mod, params = tvm.relay.testing.resnet.get_workload( + 1, num_class, num_layers=n) + _verify_results(mod, params, in_data) + + +def test_squeezenet(): + in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) + in_data = get_data(in_data_shapes, dtype="float32") + for version in ['1.0', '1.1']: + mod, params = tvm.relay.testing.squeezenet.get_workload(1, version=version) + _verify_results(mod, params, in_data) + + +if __name__ == '__main__': + test_resnet() + test_squeezenet() From 1fda66ae64cf46006073ccff1eb32f45d6b62b8f Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 12 Mar 2020 19:11:59 +0530 Subject: [PATCH 04/28] Add test cases to jenkins --- Jenkinsfile | 1 + tests/scripts/task_python_converter.sh | 35 ++++++++++++++++++++++++++ 2 files changed, 36 insertions(+) create mode 100755 tests/scripts/task_python_converter.sh diff --git a/Jenkinsfile b/Jenkinsfile index 60ee14249d28..43b42ad65307 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -254,6 +254,7 @@ stage('Integration Test') { unpack_lib('gpu', tvm_multilib) timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh" + sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_converter.sh" } } } diff --git a/tests/scripts/task_python_converter.sh b/tests/scripts/task_python_converter.sh new file mode 100755 index 000000000000..fffe2aac44be --- /dev/null +++ b/tests/scripts/task_python_converter.sh @@ -0,0 +1,35 @@ +#!/bin/bash +# Licensed to the Apache Software Foundation (ASF) under 
one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -e +set -u + +export PYTHONPATH=python:topi/python +# to avoid openblas threading error +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + +find . -type f -path "*.pyc" | xargs rm -f + +# Rebuild cython +make cython3 + +echo "Running relay to ONNX converter..." +python3 -m pytest -v tests/python/relay/converter/ + + From 52b9bd168ed478d0ceae6df788de6e8ab0ce6446 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 12 Mar 2020 21:20:12 +0530 Subject: [PATCH 05/28] CI CD fixes --- tests/python/{relay => }/converter/test_model.py | 0 tests/python/{relay => }/converter/test_onnx.py | 2 +- tests/scripts/task_python_converter.sh | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename tests/python/{relay => }/converter/test_model.py (100%) rename tests/python/{relay => }/converter/test_onnx.py (99%) diff --git a/tests/python/relay/converter/test_model.py b/tests/python/converter/test_model.py similarity index 100% rename from tests/python/relay/converter/test_model.py rename to tests/python/converter/test_model.py diff --git a/tests/python/relay/converter/test_onnx.py b/tests/python/converter/test_onnx.py similarity index 99% rename from tests/python/relay/converter/test_onnx.py rename to tests/python/converter/test_onnx.py index f52a9c12644c..bc6a78dfe805 100644 --- a/tests/python/relay/converter/test_onnx.py +++ b/tests/python/converter/test_onnx.py @@ -73,7 +73,7 @@ def test_bias_add(): for dtype in ['float16', 'float32']: xshape = (10, 2, 3, 4) bshape = (2,) - rtol = 1e-2 if dtype is 'float16' else 1e-5 + rtol = 1e-2 if dtype == 'float16' else 1e-5 x = relay.var("x", shape=xshape, dtype=dtype) bias = relay.var("bias", dtype=dtype) z = relay.nn.bias_add(x, bias) diff --git a/tests/scripts/task_python_converter.sh b/tests/scripts/task_python_converter.sh index fffe2aac44be..7faccb2d6a4c 100755 --- a/tests/scripts/task_python_converter.sh +++ b/tests/scripts/task_python_converter.sh @@ -30,6 +30,6 @@ find . -type f -path "*.pyc" | xargs rm -f make cython3 echo "Running relay to ONNX converter..." 
-python3 -m pytest -v tests/python/relay/converter/ +python3 -m pytest -v tests/python/converter/ From 0730fb0dcf732545d4f85d7582c06143b284913b Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 14 May 2020 06:20:10 +0530 Subject: [PATCH 06/28] ONNX codegen --- .../modules/contrib/ONNX.cmake | 10 +- .../onnx.py => contrib/codegen_onnx.py} | 60 ++++++- src/target/source/onnx_module.cc | 85 +++++++++ .../{converter => contrib}/test_onnx.py | 3 +- tests/python/contrib/test_onnx_model.py | 162 ++++++++++++++++++ tests/python/converter/test_model.py | 85 --------- 6 files changed, 310 insertions(+), 95 deletions(-) rename python/tvm/relay/converter/__init__.py => cmake/modules/contrib/ONNX.cmake (79%) rename python/tvm/{relay/converter/onnx.py => contrib/codegen_onnx.py} (91%) create mode 100644 src/target/source/onnx_module.cc rename tests/python/{converter => contrib}/test_onnx.py (99%) create mode 100644 tests/python/contrib/test_onnx_model.py delete mode 100644 tests/python/converter/test_model.py diff --git a/python/tvm/relay/converter/__init__.py b/cmake/modules/contrib/ONNX.cmake similarity index 79% rename from python/tvm/relay/converter/__init__.py rename to cmake/modules/contrib/ONNX.cmake index d0246c682233..c4a791b372b3 100644 --- a/python/tvm/relay/converter/__init__.py +++ b/cmake/modules/contrib/ONNX.cmake @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -""" -Converters for Relay to others IRs -Contain the converters for converting the Relay to other IRs -""" -from .onnx import to_onnx +if(USE_ONNX_CODEGEN) + message(STATUS "Build with contrib.codegen_onnx") + file(GLOB ONNX_CONTRIB_SRC src/target/source/onnx_module.cc) + list(APPEND RUNTIME_SRCS ${ONNX_CONTRIB_SRC}) +endif(USE_ONNX_CODEGEN) diff --git a/python/tvm/relay/converter/onnx.py b/python/tvm/contrib/codegen_onnx.py similarity index 91% rename from python/tvm/relay/converter/onnx.py rename to python/tvm/contrib/codegen_onnx.py index 05ce6e701dc6..a5f4c7119501 100644 --- a/python/tvm/relay/converter/onnx.py +++ b/python/tvm/contrib/codegen_onnx.py @@ -17,11 +17,14 @@ # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines, redefined-builtin """Relay to ONNX serialization """ +import os +import struct import numpy import onnx import onnx.utils from onnx import numpy_helper, OperatorSetIdProto, defs import tvm +import tvm._ffi from tvm.autotvm.graph_tuner.utils.traverse_graph import _expr2graph_impl from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple @@ -658,12 +661,12 @@ def _add_output(self, node_entry, idx): self._mc.add_outputs([output]) -def to_onnx(relay_module, params, name, opset_version=11, path=None): +def to_onnx(relay_ir, params, name, opset_version=11, path=None): """Convert a Relay Function Module into an equivalent ONNX and serialize it to the path Parameters ---------- - relay_module : tvm.relay.Module + relay_ir : tvm.ir.IRModule or tvm.relay.Function The relay module object params : dict @@ -689,10 +692,61 @@ def to_onnx(relay_module, params, name, opset_version=11, path=None): node_list = [] # ONNX needs a topologically sorted list of nodes node_dict = {} - _expr2graph_impl(relay_module["main"], [], node_dict, node_list) + func = relay_ir["main"] if isinstance(relay_ir, tvm.ir.IRModule) else relay_ir + _expr2graph_impl(func, [], node_dict, node_list) converter = RelayToONNXConverter(name, node_list, params, opset_version) onnx_model = 
converter.convert_to_onnx() if path: onnx.save(onnx_model, path) return onnx_model + + +@tvm._ffi.register_func("relay.ext.onnx") +def onnx_compiler(ref): + """Create a runtime module for ONNX from IRModule + + :param ref: IRModule subgraphs for onnx codegen + :return: runtime module for ONNX + """ + data = b'' + if isinstance(ref, tvm.ir.module.IRModule): + for var, func in ref.functions.items(): + name = var.name_hint + model = to_onnx(func, {}, name) + name_bytes = bytes(name, 'utf-8') + name_size = struct.pack('I', len(name_bytes)) + model_serialized = model.SerializeToString() + model_size = struct.pack('I', model.ByteSize()) + + data += name_size + name_bytes + model_size + model_serialized + + runtime_func = "runtime.ONNXModuleCreate" + fcreate = tvm._ffi.get_global_func(runtime_func) + return fcreate(data.hex()) + + +@tvm._ffi.register_func("relay.ext.onnx.save_to_file") +def save_to_file(hex_str, path=None, fmt="onnx"): + """ Store the ONNX subgraphs in the path folder + + :param hex_str: Subgrah names and corresponding serialized onnx hex string + :param path: path to which ONNX files to be stored + It is assumed that path exists + :param fmt: extension of the files to be stored + """ + onnx_ir = bytes.fromhex(hex_str) + + offset = 0 + while offset < len(onnx_ir): + stop = offset + 4 + (name_size,) = struct.unpack('I', onnx_ir[offset:stop]) + name = onnx_ir[stop : stop + name_size].decode("utf-8") + stop = stop + name_size + (model_size,) = struct.unpack('I', onnx_ir[stop:stop + 4]) + stop = stop + 4 + model_serialized = onnx_ir[stop:stop + model_size] + offset = stop + model_size + + model_onnx = onnx.load_model_from_string(model_serialized) + onnx.save(model_onnx, "{}{}{}.{}".format(path, os.path.sep, name, fmt)) diff --git a/src/target/source/onnx_module.cc b/src/target/source/onnx_module.cc new file mode 100644 index 000000000000..5b148dda9f32 --- /dev/null +++ b/src/target/source/onnx_module.cc @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file onnx_module.cc + * \brief ONNX Module without runtime support + */ +#include +#include +#include "codegen_source_base.h" +#include "../../runtime/file_util.h" +#include "../../runtime/meta_data.h" + +namespace tvm { +namespace codegen { + +using runtime::TVMArgs; +using runtime::TVMRetValue; +using runtime::PackedFunc; + +using runtime::GetFileFormat; +using runtime::GetMetaFilePath; +using runtime::FunctionInfo; +using runtime::SaveBinaryToFile; + +class ONNXSourceModuleNode : public runtime::ModuleNode { + public: + explicit ONNXSourceModuleNode(String code) + : code_(code) {} + + const char* type_key() const { + return "onnx"; + } + + PackedFunc GetFunction( + const std::string& name, + const ObjectPtr& sptr_to_self) final { + LOG(FATAL) << "ONNX Source module cannot execute, to get executable module" + << " build TVM with onnx runtime support"; + return PackedFunc(); + } + + std::string GetSource(const std::string& format) final { + return code_; + } + + void SaveToFile(const std::string& path, + const std::string& format) final { + CHECK_EQ(format, "onnx") + << "Can only save to onnx format"; + CHECK_NE(code_.length(), 0); + const PackedFunc* to_onnx_ = runtime::Registry::Get("relay.ext.onnx.save_to_file"); + (*to_onnx_)(code_, path, format); + } + + protected: + String code_; +}; + +runtime::Module ONNXSourceModuleNodeCreate(String code) { + auto n = make_object(code); + return runtime::Module(n); +} + +TVM_REGISTER_GLOBAL("runtime.ONNXModuleCreate") +.set_body_typed(ONNXSourceModuleNodeCreate); + +} // namespace codegen +} // namespace tvm diff --git a/tests/python/converter/test_onnx.py b/tests/python/contrib/test_onnx.py similarity index 99% rename from tests/python/converter/test_onnx.py rename to tests/python/contrib/test_onnx.py index bc6a78dfe805..a6f954d773eb 100644 --- a/tests/python/converter/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -19,7 +19,7 @@ import numpy as np import tvm from tvm import relay -from tvm.relay.converter import to_onnx +from tvm.contrib.codegen_onnx import to_onnx import onnxruntime as rt @@ -382,7 +382,6 @@ def check_binary_op(opfunc, dtype): for dtype in ['float32']: check_binary_op(opfunc, dtype) - if __name__ == '__main__': test_add() test_bias_add() diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py new file mode 100644 index 000000000000..981a604cb391 --- /dev/null +++ b/tests/python/contrib/test_onnx_model.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
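+#
+# Besides whole-model conversion, this file exercises the ONNX codegen as a
+# BYOC backend: test_partition() below annotates subgraphs with
+# compiler_begin(expr, "onnx") / compiler_end(expr, "onnx"), runs
+# transform.PartitionGraph(), and builds with the FuseOps pass disabled.
+# A minimal sketch of the annotation step (in_1 and in_2 stand for arbitrary
+# Relay vars):
+#
+#   lhs = compiler_begin(in_1, "onnx")
+#   rhs = compiler_begin(in_2, "onnx")
+#   out = compiler_end(relay.add(lhs, rhs), "onnx")
+#
+# After relay.build(), the "onnx" regions appear as imported submodules whose
+# save() serializes them to .onnx files via relay.ext.onnx.save_to_file.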
+ +"""Relay to ONNX serialization test cases""" +from collections import OrderedDict +import numpy as np +import onnxruntime as rt +import tvm +from tvm import relay +from tvm.contrib.codegen_onnx import to_onnx +import tvm.relay.testing +from tvm.relay.op.annotation import compiler_begin, compiler_end +from tvm.ir import IRModule +from tvm.relay import transform + + +def func_to_onnx(mod, params, name): + onnx_model = to_onnx(mod, params, name, path=None) + return onnx_model.SerializeToString() + + +def run_onnx(mod, params, name, input_data): + onnx_model = func_to_onnx(mod, params, name) + sess = rt.InferenceSession(onnx_model) + input_names = {} + for input, data in zip(sess.get_inputs(), input_data): + input_names[input.name] = data + output_names = [output.name for output in sess.get_outputs()] + res = sess.run(output_names, input_names) + return res[0] + + +def get_data(in_data_shapes, dtype='float32'): + in_data = OrderedDict() + for name, shape in in_data_shapes.items(): + in_data[name] = np.random.uniform(size=shape).astype(dtype) + return in_data + + +def run_relay(mod, params, in_data): + target = 'llvm' + ctx = tvm.context('llvm', 0) + intrp = relay.create_executor("graph", mod, ctx=ctx, target=target) + in_data = [tvm.nd.array(value) for value in in_data.values()] + return intrp.evaluate()(*in_data, **params).asnumpy() + + +def _verify_results(mod, params, in_data): + a = run_relay(mod, params, in_data) + b = run_onnx(mod, params, 'test_resent', in_data.values()) + np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7) + + +def test_resnet(): + num_class = 1000 + in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) + in_data = get_data(in_data_shapes, dtype="float32") + for n in [18, 34, 50, 101]: + mod, params = tvm.relay.testing.resnet.get_workload( + 1, num_class, num_layers=n) + _verify_results(mod, params, in_data) + + +def test_squeezenet(): + in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) + in_data = get_data(in_data_shapes, dtype="float32") + for version in ['1.0', '1.1']: + mod, params = tvm.relay.testing.squeezenet.get_workload(1, version=version) + _verify_results(mod, params, in_data) + + +def test_partition(): + in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') + in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') + in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') + in_4 = relay.var('in_4', shape=(10, 10), dtype='float32') + in_5 = relay.var('in_5', shape=(10, 10), dtype='float32') + in_6 = relay.var('in_6', shape=(10, 10), dtype='float32') + in_7 = relay.var('in_7', shape=(10, 10), dtype='float32') + in_8 = relay.var('in_8', shape=(10, 10), dtype='float32') + in_9 = relay.var('in_9', shape=(10, 10), dtype='float32') + in_10 = relay.var('in_10', shape=(10, 10), dtype='float32') + + begin0 = compiler_begin(in_1, "onnx") + begin1 = compiler_begin(in_2, "onnx") + begin2 = compiler_begin(in_3, "onnx") + begin3 = compiler_begin(in_4, "onnx") + node0 = relay.add(begin0, begin1) + node1 = relay.add(begin2, begin3) + end0 = compiler_end(node0, "onnx") + end1 = compiler_end(node1, "onnx") + begin4 = compiler_begin(end0, "onnx") + begin5 = compiler_begin(end1, "onnx") + node2 = relay.add(begin4, begin5) + end2 = compiler_end(node2, "onnx") + + dbegin0 = compiler_begin(in_5, "default") + dbegin1 = compiler_begin(in_6, "default") + node3 = relay.subtract(dbegin0, dbegin1) + dbegin2 = compiler_begin(in_7, "default") + dend1 = compiler_end(node3, "default") + dbegin3 = compiler_begin(dend1, "default") + node4 = relay.subtract(dbegin2, 
dbegin3) + dend2 = compiler_end(node4, "default") + + begin6 = compiler_begin(end2, "onnx") + begin7 = compiler_begin(dend2, "onnx") + node5 = relay.add(begin6, begin7) + end3 = compiler_end(node5, "onnx") + end4 = compiler_end(node5, "onnx") + dbegin4 = compiler_begin(in_8, "default") + dbegin5 = compiler_begin(end3, "default") + node6 = relay.subtract(dbegin4, dbegin5) + begin8 = compiler_begin(in_9, "onnx") + begin9 = compiler_begin(end4, "onnx") + node7 = relay.multiply(begin8, begin9) + end5 = compiler_end(node7, "onnx") + + dend3 = compiler_end(node6, "default") + begin10 = compiler_begin(dend3, "onnx") + begin11 = compiler_begin(end5, "onnx") + node8 = relay.add(begin10, begin11) + end6 = compiler_end(node8, "onnx") + begin12 = compiler_begin(in_10, "onnx") + begin13 = compiler_begin(end6, "onnx") + node9 = relay.add(begin12, begin13) + end7 = compiler_end(node9, "onnx") + + func = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7) + + target = 'llvm' + mod = IRModule() + expr = func + mod["main"] = expr + mod = transform.PartitionGraph()(mod) + + with relay.build_config(opt_level=3, disabled_pass=['FuseOps']): + graph_json, mod1, params = relay.build(mod, target) + + assert mod1.type_key == "llvm" + assert mod1.imported_modules[0].type_key == "onnx" + assert mod1.imported_modules[0].get_source() + mod1.imported_modules[0].save("/Users/mahesh/tmp/", "onnx") + + +if __name__ == '__main__': + test_resnet() + test_squeezenet() diff --git a/tests/python/converter/test_model.py b/tests/python/converter/test_model.py deleted file mode 100644 index d8e982e58cca..000000000000 --- a/tests/python/converter/test_model.py +++ /dev/null @@ -1,85 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -"""Relay to ONNX serialization test cases""" -from collections import OrderedDict -import numpy as np -import tvm -from tvm import relay -from tvm.relay.converter import to_onnx -import onnxruntime as rt -import tvm.relay.testing - - -def func_to_onnx(mod, params, name): - onnx_model = to_onnx(mod, params, name, path=None) - return onnx_model.SerializeToString() - - -def run_onnx(mod, params, name, input_data): - onnx_model = func_to_onnx(mod, params, name) - sess = rt.InferenceSession(onnx_model) - input_names = {} - for input, data in zip(sess.get_inputs(), input_data): - input_names[input.name] = data - output_names = [output.name for output in sess.get_outputs()] - res = sess.run(output_names, input_names) - return res[0] - - -def get_data(in_data_shapes, dtype='float32'): - in_data = OrderedDict() - for name, shape in in_data_shapes.items(): - in_data[name] = np.random.uniform(size=shape).astype(dtype) - return in_data - - -def run_relay(mod, params, in_data): - target = 'llvm' - ctx = tvm.context('llvm', 0) - intrp = relay.create_executor("graph", mod, ctx=ctx, target=target) - in_data = [tvm.nd.array(value) for value in in_data.values()] - return intrp.evaluate()(*in_data, **params).asnumpy() - - -def _verify_results(mod, params, in_data): - a = run_relay(mod, params, in_data) - b = run_onnx(mod, params, 'test_resent', in_data.values()) - np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7) - - -def test_resnet(): - num_class = 1000 - in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) - in_data = get_data(in_data_shapes, dtype="float32") - for n in [18, 34, 50, 101]: - mod, params = tvm.relay.testing.resnet.get_workload( - 1, num_class, num_layers=n) - _verify_results(mod, params, in_data) - - -def test_squeezenet(): - in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) - in_data = get_data(in_data_shapes, dtype="float32") - for version in ['1.0', '1.1']: - mod, params = tvm.relay.testing.squeezenet.get_workload(1, version=version) - _verify_results(mod, params, in_data) - - -if __name__ == '__main__': - test_resnet() - test_squeezenet() From 820d9f0aa0afc79549b66a4aa5b6e1b0fa5038bc Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 14 May 2020 06:26:24 +0530 Subject: [PATCH 07/28] ONNX codegen --- CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 18f58c8ccb9c..c2fc5fa3d3f2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,6 +69,7 @@ tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_COREML "Build with coreml support" OFF) +tvm_option(USE_ONNX_CODEGEN "Build with ONNX Codegen support" OFF) if(USE_CPP_RPC AND UNIX) message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. 
Use the Makefile for non-Windows.")
@@ -325,6 +326,7 @@ include(cmake/modules/contrib/HybridDump.cmake)
 include(cmake/modules/contrib/TFLite.cmake)
 include(cmake/modules/contrib/TF_TVMDSOOP.cmake)
 include(cmake/modules/contrib/CoreML.cmake)
+include(cmake/modules/contrib/ONNX.cmake)
 
 include(CheckCXXCompilerFlag)
 if(NOT MSVC)

From 1c0352815ff7845a7c002ecbeaf8d7fa4835db93 Mon Sep 17 00:00:00 2001
From: maheshambule
Date: Thu, 14 May 2020 07:02:48 +0530
Subject: [PATCH 08/28] ONNX codegen
---
 cmake/config.cmake               |  3 ++
 cmake/modules/contrib/ONNX.cmake |  2 +-
 .../contrib/onnx}/onnx_module.cc | 28 ++++++-------------
 3 files changed, 13 insertions(+), 20 deletions(-)
 rename src/{target/source => runtime/contrib/onnx}/onnx_module.cc (74%)

diff --git a/cmake/config.cmake b/cmake/config.cmake
index 7e5734e952a9..2a2150abcce0 100644
--- a/cmake/config.cmake
+++ b/cmake/config.cmake
@@ -215,3 +215,6 @@ set(USE_TF_TVMDSOOP OFF)
 # Whether to use hexagon device
 set(USE_HEXAGON_DEVICE OFF)
 set(USE_HEXAGON_SDK /path/to/sdk)
+
+# Whether to use ONNX codegen
+set(USE_ONNX_CODEGEN OFF)
diff --git a/cmake/modules/contrib/ONNX.cmake b/cmake/modules/contrib/ONNX.cmake
index c4a791b372b3..7d780a977dba 100644
--- a/cmake/modules/contrib/ONNX.cmake
+++ b/cmake/modules/contrib/ONNX.cmake
@@ -17,6 +17,6 @@
 if(USE_ONNX_CODEGEN)
   message(STATUS "Build with contrib.codegen_onnx")
-  file(GLOB ONNX_CONTRIB_SRC src/target/source/onnx_module.cc)
+  file(GLOB ONNX_CONTRIB_SRC src/runtime/contrib/onnx/onnx_module.cc)
   list(APPEND RUNTIME_SRCS ${ONNX_CONTRIB_SRC})
 endif(USE_ONNX_CODEGEN)
diff --git a/src/target/source/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc
similarity index 74%
rename from src/target/source/onnx_module.cc
rename to src/runtime/contrib/onnx/onnx_module.cc
index 5b148dda9f32..96844e043e1c 100644
--- a/src/target/source/onnx_module.cc
+++ b/src/runtime/contrib/onnx/onnx_module.cc
@@ -23,21 +23,11 @@
  */
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
-#include "codegen_source_base.h"
-#include "../../runtime/file_util.h"
-#include "../../runtime/meta_data.h"
+#include <tvm/runtime/module.h>
 
 namespace tvm {
 namespace codegen {
-
-using runtime::TVMArgs;
-using runtime::TVMRetValue;
-using runtime::PackedFunc;
-
-using runtime::GetFileFormat;
-using runtime::GetMetaFilePath;
-using runtime::FunctionInfo;
-using runtime::SaveBinaryToFile;
+using namespace tvm::runtime;
 
 class ONNXSourceModuleNode : public runtime::ModuleNode {
  public:
@@ -49,11 +39,11 @@ class ONNXSourceModuleNode : public runtime::ModuleNode {
   }
 
   PackedFunc GetFunction(
-      const std::string& name,
-      const ObjectPtr<Object>& sptr_to_self) final {
-    LOG(FATAL) << "ONNX Source module cannot execute. To get an executable module,"
-               << " build TVM with onnx runtime support";
-    return PackedFunc();
+      const std::string& name,
+      const ObjectPtr<Object>& sptr_to_self) final {
+    LOG(FATAL) << "ONNX Source module cannot execute. To get an executable module,"
+               << " build TVM with 'onnx' runtime support";
+    return PackedFunc();
   }
 
   std::string GetSource(const std::string& format) final {
@@ -73,9 +63,9 @@ class ONNXSourceModuleNode : public runtime::ModuleNode {
   String code_;
 };
 
-runtime::Module ONNXSourceModuleNodeCreate(String code) {
+Module ONNXSourceModuleNodeCreate(String code) {
   auto n = make_object<ONNXSourceModuleNode>(code);
-  return runtime::Module(n);
+  return Module(n);
 }
 
 TVM_REGISTER_GLOBAL("runtime.ONNXModuleCreate")

From 26d1d87789e5b97345a1a6aa5109608953ef7108 Mon Sep 17 00:00:00 2001
From: maheshambule
Date: Thu, 14 May 2020 07:16:37 +0530
Subject: [PATCH 09/28] onnx testcases
---
 Jenkinsfile | 1 -
tests/python/contrib/test_onnx_model.py | 4 +-- tests/scripts/task_python_converter.sh | 35 ------------------------- 3 files changed, 2 insertions(+), 38 deletions(-) delete mode 100755 tests/scripts/task_python_converter.sh diff --git a/Jenkinsfile b/Jenkinsfile index 43b42ad65307..60ee14249d28 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -254,7 +254,6 @@ stage('Integration Test') { unpack_lib('gpu', tvm_multilib) timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_frontend.sh" - sh "${docker_run} ${ci_gpu} ./tests/scripts/task_python_converter.sh" } } } diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index 981a604cb391..a1ccbe75e3be 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -83,7 +83,7 @@ def test_squeezenet(): _verify_results(mod, params, in_data) -def test_partition(): +def skipped_test_partition(): in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') @@ -154,9 +154,9 @@ def test_partition(): assert mod1.type_key == "llvm" assert mod1.imported_modules[0].type_key == "onnx" assert mod1.imported_modules[0].get_source() - mod1.imported_modules[0].save("/Users/mahesh/tmp/", "onnx") if __name__ == '__main__': test_resnet() test_squeezenet() + # skipped_test_partition() diff --git a/tests/scripts/task_python_converter.sh b/tests/scripts/task_python_converter.sh deleted file mode 100755 index 7faccb2d6a4c..000000000000 --- a/tests/scripts/task_python_converter.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/bin/bash -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -set -e -set -u - -export PYTHONPATH=python:topi/python -# to avoid openblas threading error -export TVM_BIND_THREADS=0 -export OMP_NUM_THREADS=1 - -find . -type f -path "*.pyc" | xargs rm -f - -# Rebuild cython -make cython3 - -echo "Running relay to ONNX converter..." 
-python3 -m pytest -v tests/python/converter/
-
-

From 7849ec65032ce3058eb7c61d6926fd9a48eb87fe Mon Sep 17 00:00:00 2001
From: maheshambule
Date: Thu, 14 May 2020 07:30:14 +0530
Subject: [PATCH 10/28] ONNX codegen
---
 src/runtime/contrib/onnx/onnx_module.cc | 28 ++++++++++-----------
 1 file changed, 9 insertions(+), 19 deletions(-)

diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc
index 96844e043e1c..69c3ccf7f1be 100644
--- a/src/runtime/contrib/onnx/onnx_module.cc
+++ b/src/runtime/contrib/onnx/onnx_module.cc
@@ -21,9 +21,9 @@
  * \file onnx_module.cc
  * \brief ONNX Module without runtime support
  */
+#include <tvm/runtime/module.h>
 #include <tvm/runtime/packed_func.h>
 #include <tvm/runtime/registry.h>
-#include <tvm/runtime/module.h>
 
 namespace tvm {
 namespace codegen {
@@ -31,29 +31,20 @@ using namespace tvm::runtime;
 
 class ONNXSourceModuleNode : public runtime::ModuleNode {
  public:
-  explicit ONNXSourceModuleNode(String code)
-      : code_(code) {}
+  explicit ONNXSourceModuleNode(String code) : code_(code) {}
 
-  const char* type_key() const {
-    return "onnx";
-  }
+  const char* type_key() const { return "onnx"; }
 
-  PackedFunc GetFunction(
-      const std::string& name,
-      const ObjectPtr<Object>& sptr_to_self) final {
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) final {
     LOG(FATAL) << "ONNX Source module cannot execute. To get an executable module,"
-               << " build TVM with 'onnx' runtime support";
+               << " build TVM with 'onnx' runtime support";
     return PackedFunc();
   }
 
-  std::string GetSource(const std::string& format) final {
-    return code_;
-  }
+  std::string GetSource(const std::string& format) final { return code_; }
 
-  void SaveToFile(const std::string& path,
-                  const std::string& format) final {
-    CHECK_EQ(format, "onnx")
-        << "Can only save to onnx format";
+  void SaveToFile(const std::string& path, const std::string& format) final {
+    CHECK_EQ(format, "onnx") << "Can only save to onnx format";
     CHECK_NE(code_.length(), 0);
     const PackedFunc* to_onnx_ = runtime::Registry::Get("relay.ext.onnx.save_to_file");
     (*to_onnx_)(code_, path, format);
@@ -68,8 +59,7 @@ Module ONNXSourceModuleNodeCreate(String code) {
   return Module(n);
 }
 
-TVM_REGISTER_GLOBAL("runtime.ONNXModuleCreate")
-.set_body_typed(ONNXSourceModuleNodeCreate);
+TVM_REGISTER_GLOBAL("runtime.ONNXModuleCreate").set_body_typed(ONNXSourceModuleNodeCreate);
 
 }  // namespace codegen
 }  // namespace tvm

From 44d417e0315900376e71dc4dfe41dba9c419135c Mon Sep 17 00:00:00 2001
From: maheshambule
Date: Thu, 14 May 2020 16:40:17 +0530
Subject: [PATCH 11/28] test onnx
---
 tests/python/contrib/test_onnx.py       | 4 ++++
 tests/python/contrib/test_onnx_model.py | 5 +++++
 2 files changed, 9 insertions(+)

diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py
index a6f954d773eb..85dac997dc6d 100644
--- a/tests/python/contrib/test_onnx.py
+++ b/tests/python/contrib/test_onnx.py
@@ -16,6 +16,10 @@
 # under the License.
 
 """Relay to ONNX serialization test cases"""
+import pytest
+pytest.importorskip('onnx')
+pytest.importorskip('onnxruntime')
+
 import numpy as np
 import tvm
 from tvm import relay
diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py
index a1ccbe75e3be..ecdf747abc97 100644
--- a/tests/python/contrib/test_onnx_model.py
+++ b/tests/python/contrib/test_onnx_model.py
@@ -16,6 +16,10 @@
 # under the License.
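# NOTE: the importorskip guards added below run at import time, so they must
# precede the top-level onnx/onnxruntime imports for the whole file to be
# skipped cleanly when those optional packages are missing.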
"""Relay to ONNX serialization test cases""" +import pytest +pytest.importorskip('onnx') +pytest.importorskip('onnxruntime') + from collections import OrderedDict import numpy as np import onnxruntime as rt @@ -159,4 +163,5 @@ def skipped_test_partition(): if __name__ == '__main__': test_resnet() test_squeezenet() + # test_partition need USE_ONNX_CODEGEN enabled # skipped_test_partition() From 406807fcffcfd07593cfba1d31948dfc0afc2cca Mon Sep 17 00:00:00 2001 From: maheshambule Date: Tue, 19 May 2020 19:34:46 +0530 Subject: [PATCH 12/28] ONNX codegen --- python/tvm/contrib/codegen_onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/codegen_onnx.py b/python/tvm/contrib/codegen_onnx.py index a5f4c7119501..40ec3065ed5c 100644 --- a/python/tvm/contrib/codegen_onnx.py +++ b/python/tvm/contrib/codegen_onnx.py @@ -118,7 +118,7 @@ class MaxPool(OpConverter): @classmethod def convert_attributes(cls, attrs): return { - 'pads': attrs.get_int_tuple("padding") + attrs.get_int_tuple("padding"), + 'pads': attrs.get_int_tuple("padding"), 'strides': attrs.get_int_tuple("strides"), 'kernel_shape': attrs.get_int_tuple("pool_size"), } From ae6b7d1e3d27bc04e8f3a8c2cc7aff1951973da0 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Wed, 20 May 2020 17:41:47 +0530 Subject: [PATCH 13/28] shape calculation --- python/tvm/contrib/codegen_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/codegen_onnx.py b/python/tvm/contrib/codegen_onnx.py index 40ec3065ed5c..dfff433db095 100644 --- a/python/tvm/contrib/codegen_onnx.py +++ b/python/tvm/contrib/codegen_onnx.py @@ -441,9 +441,9 @@ def convert_attributes(cls, attrs): def convert(cls, node, model_container, node_list): attrs = cls.convert_attributes(node['node'].attrs) input_node = node_list[node['inputs'][0][0]] - shape = input_node['types'][0].shape - dtype = input_node['types'][0].dtype + dtype = input_node['node'].type_annotation.dtype input_shape_name = 'shape_{}'.format(node['output_names'][0]) + shape = [val.value for val in input_node['node'].type_annotation.shape] shape = numpy.asarray(shape).astype(numpy.int64) add_input(shape, input_shape_name, model_container) From 0370f413083da8e57e3972dcbd2d0521f728be7c Mon Sep 17 00:00:00 2001 From: maheshambule Date: Fri, 22 May 2020 12:04:51 +0530 Subject: [PATCH 14/28] move onnx codegen to contrib/target --- python/tvm/contrib/target/__init__.py | 18 ++++++++++++++++++ .../{codegen_onnx.py => target/onnx.py} | 2 +- tests/python/contrib/test_onnx.py | 2 +- tests/python/contrib/test_onnx_model.py | 2 +- 4 files changed, 21 insertions(+), 3 deletions(-) create mode 100644 python/tvm/contrib/target/__init__.py rename python/tvm/contrib/{codegen_onnx.py => target/onnx.py} (99%) diff --git a/python/tvm/contrib/target/__init__.py b/python/tvm/contrib/target/__init__.py new file mode 100644 index 000000000000..7d815413f28a --- /dev/null +++ b/python/tvm/contrib/target/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Codegen and runtime APIs for targets. +""" diff --git a/python/tvm/contrib/codegen_onnx.py b/python/tvm/contrib/target/onnx.py similarity index 99% rename from python/tvm/contrib/codegen_onnx.py rename to python/tvm/contrib/target/onnx.py index dfff433db095..12464ad8e5e3 100644 --- a/python/tvm/contrib/codegen_onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name, import-self, len-as-condition, unused-argument, too-many-lines, redefined-builtin -"""Relay to ONNX serialization """ +"""Relay to ONNX codegen """ import os import struct diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index 85dac997dc6d..c6dc23e7adaf 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -23,7 +23,7 @@ import numpy as np import tvm from tvm import relay -from tvm.contrib.codegen_onnx import to_onnx +from tvm.contrib.target.onnx import to_onnx import onnxruntime as rt diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index ecdf747abc97..5e350113fdc7 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -25,7 +25,7 @@ import onnxruntime as rt import tvm from tvm import relay -from tvm.contrib.codegen_onnx import to_onnx +from tvm.contrib.target.onnx import to_onnx import tvm.relay.testing from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.ir import IRModule From 427815a871dcbe81eac4fc0678d67618381b4c07 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Fri, 22 May 2020 19:17:16 +0530 Subject: [PATCH 15/28] review comments --- python/tvm/contrib/target/onnx.py | 25 ++++++++++++++----------- tests/python/contrib/test_onnx.py | 13 ------------- 2 files changed, 14 insertions(+), 24 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 12464ad8e5e3..f79535c9395f 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -143,7 +143,7 @@ def convert(cls, node, model_container, node_list): transpose_node = onnx.helper.make_node(Transpose.__name__, [node['input_names'][1]], [output_name], - **{'perm': (1, 0)}) + perm=(1, 0)) model_container.add_nodes([transpose_node]) inputs = [node['input_names'][0], output_name] @@ -188,21 +188,21 @@ def convert(cls, node, model_container, node_list): node_transposed = onnx.helper.make_node(Transpose.__name__, [node['input_names'][0]], [transpose_out_name], - **{'perm': [0, 3, 1, 2]}) + perm=[0, 3, 1, 2]) model_container.add_nodes([node_transposed]) output_names = ['batch_norm_{}'.format(node['output_names'][0])] batch_norm_node = onnx.helper.make_node(cls.__name__, [transpose_out_name] + node['input_names'][1:], output_names, - **{'epsilon': attrs['epsilon']}) + epsilon=attrs['epsilon']) model_container.add_nodes([batch_norm_node]) if attrs['axis'] == 3: node_transposed = onnx.helper.make_node(Transpose.__name__, output_names, node['output_names'], - **{'perm': [0, 2, 3, 1]}) + perm=[0, 2, 3, 
1]) model_container.add_nodes([node_transposed]) @@ -250,7 +250,7 @@ def convert(cls, node, model_container, node_list): unsqueeze_node = onnx.helper.make_node('Unsqueeze', [node['input_names'][1]], [output_name], - **{'axes': tuple(range(1, new_axes + 1))}) + axes=tuple(range(1, new_axes + 1))) model_container.add_nodes([unsqueeze_node]) else: output_name = node['input_names'][1] @@ -286,8 +286,8 @@ def convert(cls, node, model_container, node_list): node = onnx.helper.make_node(cls.__name__, node['input_names'], node['output_names'], - **{"axes": axis, - "keepdims": keepdims}) + axes=axis, + keepdims=keepdims) model_container.add_nodes([node]) @@ -367,7 +367,7 @@ def convert(cls, node, model_container, node_list): node = onnx.helper.make_node(cls.__name__, node['input_names'], node['output_names'], - **{"axes": axis}) + axes=axis) model_container.add_nodes([node]) @@ -454,7 +454,7 @@ def convert(cls, node, model_container, node_list): node = onnx.helper.make_node('ConstantOfShape', [input_shape_name], node['output_names'], - **{'value': tensor_value}) + value=tensor_value) model_container.add_nodes([node]) @@ -672,13 +672,16 @@ def to_onnx(relay_ir, params, name, opset_version=11, path=None): params : dict dict of the parameter names and NDarray values + name : str + name of the output ONNX graph + path : str The path where ONNX model will be saved Returns ------- - inferred_model : tvm.relay.Module - The relay module + onnx_model : onnx.ModelProto + converted ONNX model as a ModelProto. """ diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index c6dc23e7adaf..e6550f10231f 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -218,19 +218,6 @@ def verify_test_batch_flatten(d_shape): verify_test_batch_flatten((1, 8)) -def test_bias_add(): - def verify_bias_add(): - data = relay.var("data", relay.TensorType((1, 16), "float32")) - bias = relay.var("bias", relay.TensorType((16,), "float32")) - func = relay.Function([data, bias], relay.nn.bias_add(data, bias)) - - x_data = np.random.uniform(size=(1, 16)).astype("float32") - bias = np.random.uniform(size=(16,)).astype("float32") - verify_results(func, [x_data, bias], 'test_bias_add', rtol=1e-5, atol=1e-5) - - verify_bias_add() - - def test_batch_norm(): def verify_batch_norm(axis=1): for dtype in ['float16', 'float32']: From f5955478b2b268d08bb66c91dbecc83351f95a26 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 6 Jun 2020 14:27:37 +0530 Subject: [PATCH 16/28] ONNX target use visitor --- python/tvm/contrib/target/onnx.py | 410 +++++++++++++++++++++--------- tests/python/contrib/test_onnx.py | 73 ++++-- 2 files changed, 341 insertions(+), 142 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index f79535c9395f..88d29b5bf8dc 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -24,9 +24,11 @@ import onnx.utils from onnx import numpy_helper, OperatorSetIdProto, defs import tvm +from tvm import relay import tvm._ffi -from tvm.autotvm.graph_tuner.utils.traverse_graph import _expr2graph_impl -from tvm.relay.expr import Call, TupleGetItem, Var, Constant, Tuple +from tvm.relay.expr_functor import ExprVisitor +from tvm.relay.ty import TupleType, TensorType + ONNX_OPSET_VERSONS_SUPPORTED = [11] @@ -39,6 +41,30 @@ def get_onnx_version(): return onnx.__version__ +def infer_type(node): + """A method to infer the type of a relay expression.""" + mod = tvm.IRModule.from_expr(node) + mod = 
relay.transform.InferType()(mod)
+    entry = mod["main"]
+    return entry if isinstance(node, relay.Function) else entry.body
+
+
+def call_node_infer_type(node):
+    infer_out = infer_type(node)
+    out_type = infer_out._checked_type_
+    types = []
+    if isinstance(out_type, TensorType):
+        types.append(out_type)
+    elif isinstance(out_type, TupleType):
+        for tuple_type in out_type.fields:
+            types.append(tuple_type)
+    else:
+        raise RuntimeError("Unsupported output type %s in operator %s"
+                           % (type(out_type), node.op.name))
+
+    return types
+
+
 def add_input(data, name, model_container):
     dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[data.dtype]
     tensor_value_info = onnx.helper.make_tensor_value_info(name, dtype, shape=data.shape)
@@ -61,13 +87,15 @@ def convert_attributes(cls, attrs):
         return {}
 
     @classmethod
-    def convert(cls, node, model_container, node_list):
-        attrs = cls.convert_attributes(node['node'].attrs)
-        node = onnx.helper.make_node(cls.__name__,
-                                     node['input_names'],
-                                     node['output_names'],
+    def convert(cls, node_entry, model_container, node_dict):
+        attrs = cls.convert_attributes(node_entry['relay_node'].attrs)
+        output_names = [node_entry['name']]
+        onnx_node = onnx.helper.make_node(cls.__name__,
+                                          node_entry['input_names'],
+                                          output_names,
                                      **attrs)
-        model_container.add_nodes([node])
+        model_container.add_nodes([onnx_node])
+        return output_names
 
 
 def rename(op_name):
@@ -81,20 +109,21 @@ class Reshape(object):
     """
 
     @classmethod
-    def convert(cls, node, model_container, node_list):
+    def convert(cls, node_entry, model_container, node_dict):
         """Converts Relay operator Reshape to ONNX operator.
         Relay operator accepts shape as attribute but ONNX operator
         accepts it as a input.
         """
 
-        shape = numpy.asarray([a.value for a in node['node'].attrs.newshape],
+        shape = numpy.asarray([a.value for a in node_entry['relay_node'].attrs.newshape],
                               dtype=numpy.int64)
-        input_name = 'shape{}'.format(node['output_names'][0])
-        node = onnx.helper.make_node(cls.__name__, [node['input_names'][0], input_name],
-                                     node['output_names'])
+        output_name = node_entry['name']
+        input_name = 'shape{}'.format(output_name)
+        node = onnx.helper.make_node(cls.__name__, [node_entry['input_names'][0], input_name],
+                                     [output_name])
         model_container.add_nodes([node])
         add_input(shape, input_name, model_container)
-
+        return [output_name]
 
 class Conv(OpConverter):
     """ Operator converter for Conv.
@@ -138,17 +167,19 @@ class MatMul(OpConverter):
     """
 
     @classmethod
-    def convert(cls, node, model_container, node_list):
-        output_name = 'inter{}'.format(node['output_names'][0])
+    def convert(cls, node_entry, model_container, node_dict):
+        output_name = node_entry['name']
+        inter_output_name = 'inter{}'.format(output_name)
         transpose_node = onnx.helper.make_node(Transpose.__name__,
-                                               [node['input_names'][1]],
-                                               [output_name],
+                                               [node_entry['input_names'][1]],
+                                               [inter_output_name],
                                                perm=(1, 0))
         model_container.add_nodes([transpose_node])
 
-        inputs = [node['input_names'][0], output_name]
-        matmul_node = onnx.helper.make_node(cls.__name__, inputs, node['output_names'])
+        inputs = [node_entry['input_names'][0], inter_output_name]
+        matmul_node = onnx.helper.make_node(cls.__name__, inputs, [output_name])
         model_container.add_nodes([matmul_node])
+        return [output_name]
 
 
 class Flatten(OpConverter):
@@ -174,37 +205,39 @@ def convert_attributes(cls, attrs):
         }
 
     @classmethod
-    def convert(cls, node, model_container, node_list):
+    def convert(cls, node_entry, model_container, node_dict):
         """Converts Relay operator batch_norm to ONNX operator.
Relay operator has property axis to handle data in NHWC format. """ - attrs = cls.convert_attributes(node['node'].attrs) - transpose_out_name = node['input_names'][0] - output_names = node['output_names'] + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + transpose_out_name = node_entry['input_names'][0] + output_names = [node_entry['name']] # axis==3 means channel is specified along the 3rd axis if attrs['axis'] == 3: - transpose_out_name = 'transpose_{}'.format(node['output_names'][0]) + transpose_out_name = 'transpose_{}'.format(output_names[0]) node_transposed = onnx.helper.make_node(Transpose.__name__, - [node['input_names'][0]], + [node_entry['input_names'][0]], [transpose_out_name], perm=[0, 3, 1, 2]) model_container.add_nodes([node_transposed]) - output_names = ['batch_norm_{}'.format(node['output_names'][0])] + output_names = ['batch_norm_{}'.format(output_names[0])] batch_norm_node = onnx.helper.make_node(cls.__name__, - [transpose_out_name] + node['input_names'][1:], + [transpose_out_name] + node_entry['input_names'][1:], output_names, epsilon=attrs['epsilon']) model_container.add_nodes([batch_norm_node]) - + final_out_names = output_names if attrs['axis'] == 3: node_transposed = onnx.helper.make_node(Transpose.__name__, output_names, - node['output_names'], + [node_entry['name']], perm=[0, 2, 3, 1]) model_container.add_nodes([node_transposed]) + final_out_names = [node_entry['name']] + return final_out_names class Dropout(OpConverter): """ Operator converter for Dropout. @@ -238,26 +271,30 @@ class BiasAdd(OpConverter): """ @classmethod - def convert(cls, node, model_container, node_list): - input_node = node_list[node['inputs'][0][0]] + def convert(cls, node_entry, model_container, node_dict): + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node_entry can not be a Tuple" + input_node = input_node[0] data_ndim = len(input_node['types'][0].shape) - axis = node['node'].attrs.get_int("axis") + axis = node_entry['relay_node'].attrs.get_int("axis") if axis < 0: axis = axis + data_ndim new_axes = data_ndim - axis - 1 if new_axes: - output_name = 'inter{}'.format(node['output_names'][0]) + inter_output_name = 'inter{}'.format(node_entry['name']) unsqueeze_node = onnx.helper.make_node('Unsqueeze', - [node['input_names'][1]], - [output_name], + [node_entry['input_names'][1]], + [inter_output_name], axes=tuple(range(1, new_axes + 1))) model_container.add_nodes([unsqueeze_node]) else: - output_name = node['input_names'][1] + inter_output_name = node_entry['input_names'][1] - inputs = [node['input_names'][0], output_name] - matmul_node = onnx.helper.make_node('Add', inputs, node['output_names']) + inputs = [node_entry['input_names'][0], inter_output_name] + output_names = [node_entry['name']] + matmul_node = onnx.helper.make_node('Add', inputs, output_names) model_container.add_nodes([matmul_node]) + return output_names class ReduceMean(OpConverter): @@ -272,23 +309,27 @@ def convert_attributes(cls, attrs): } @classmethod - def convert(cls, node, model_container, node_list): - input_node = node_list[node['inputs'][0][0]] + def convert(cls, node_entry, model_container, node_dict): + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] shape = input_node['types'][0].shape - axis = node['node'].attrs.axis + axis = node_entry['relay_node'].attrs.axis axis = list(range(shape.size())) if not axis else tvm_array_to_list(axis) - exclude = 0 if not 
bool(node['node'].attrs.exclude) else 1 - keepdims = 0 if not bool(node['node'].attrs.keepdims) else 1 + exclude = 0 if not bool(node_entry['relay_node'].attrs.exclude) else 1 + keepdims = 0 if not bool(node_entry['relay_node'].attrs.keepdims) else 1 if exclude: all_axis = list(range(len(shape))) axis = set(all_axis) - set(axis) + output_names = [node_entry['name']] node = onnx.helper.make_node(cls.__name__, - node['input_names'], - node['output_names'], + node_entry['input_names'], + output_names, axes=axis, keepdims=keepdims) model_container.add_nodes([node]) + return output_names class Pad(OpConverter): @@ -311,23 +352,25 @@ def convert_attributes(cls, attrs): } @classmethod - def convert(cls, node, model_container, node_list): + def convert(cls, node_entry, model_container, node_dict): """Converts Relay operator Pad to ONNX operator. Relay operator accepts pads as attribute but ONNX operator accepts it as a input. """ - attrs = cls.convert_attributes(node['node'].attrs) + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + output_names = [node_entry['name']] data = numpy.asarray(attrs['pads'], dtype=attrs['pads'][0].dtype).astype(numpy.int64) - input_name = 'pads_{}'.format(node['output_names'][0]) - value = numpy.dtype(node['types'][0].dtype).type(attrs['constant_value']) - input_value_name = 'value_{}'.format(node['output_names'][0]) + input_name = 'pads_{}'.format(output_names[0]) + value = numpy.dtype(node_entry['types'][0].dtype).type(attrs['constant_value']) + input_value_name = 'value_{}'.format(output_names[0]) add_input(data, input_name, model_container) add_input(value, input_value_name, model_container) - input_names = [node['input_names'][0], input_name, input_value_name] - node = onnx.helper.make_node(cls.__name__, input_names, node['output_names']) + input_names = [node_entry['input_names'][0], input_name, input_value_name] + node = onnx.helper.make_node(cls.__name__, input_names, output_names) model_container.add_nodes([node]) + return output_names class Softmax(OpConverter): @@ -352,23 +395,27 @@ def convert_attributes(cls, attrs): } @classmethod - def convert(cls, node, model_container, node_list): - input_node = node_list[node['inputs'][0][0]] + def convert(cls, node_entry, model_container, node_dict): + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] shape = input_node['types'][0].shape - axis = node['node'].attrs.get_int("axis") + axis = node_entry['relay_node'].attrs.get_int("axis") if not axis: axis = [] for axis_idx, val in enumerate(shape): if val.value == 1: axis.append(axis_idx) else: - axis = node['node'].attrs.get_int_tuple("axis") + axis = node_entry['relay_node'].attrs.get_int_tuple("axis") + output_names = [node_entry['name']] node = onnx.helper.make_node(cls.__name__, - node['input_names'], - node['output_names'], + node_entry['input_names'], + output_names, axes=axis) model_container.add_nodes([node]) + return output_names class Slice(OpConverter): @@ -384,10 +431,12 @@ def convert_attributes(cls, attrs): } @classmethod - def convert(cls, node, model_container, node_list): - attrs = cls.convert_attributes(node['node'].attrs) + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) - input_node = node_list[node['inputs'][0][0]] + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] shape = 
input_node['types'][0].shape starts = list(attrs['starts']) ends = list(attrs['ends']) @@ -396,15 +445,16 @@ def convert(cls, node, model_container, node_list): for i in range(len(ends), len(shape)): ends.append(shape[i] + 1) + output_names = [node_entry['name']] starts = numpy.asarray(starts).astype(numpy.int64) - starts_name = 'starts_{}'.format(node['output_names'][0]) + starts_name = 'starts_{}'.format(output_names[0]) add_input(starts, starts_name, model_container) ends = numpy.asarray(ends).astype(numpy.int64) - ends_name = 'ends_{}'.format(node['output_names'][0]) + ends_name = 'ends_{}'.format(output_names[0]) add_input(ends, ends_name, model_container) - input_names = node['input_names'] + [starts_name, ends_name] + input_names = node_entry['input_names'] + [starts_name, ends_name] if attrs['steps']: axes = list(range(len(shape))) @@ -412,19 +462,72 @@ def convert(cls, node, model_container, node_list): assert len(axes) == len(attrs['steps']), "axes and steps should be of same size" steps = numpy.asarray(attrs['steps']).astype(numpy.int64) - steps_name = 'steps_{}'.format(node['output_names'][0]) + steps_name = 'steps_{}'.format(output_names[0]) add_input(steps, steps_name, model_container) axes = numpy.asarray(attrs['axes']).astype(numpy.int64) - axes_name = 'axes_{}'.format(node['output_names'][0]) + axes_name = 'axes_{}'.format(output_names[0]) add_input(axes, axes_name, model_container) input_names = input_names + [axes_name, steps_name] slice_node = onnx.helper.make_node(cls.__name__, input_names, - node['output_names']) + output_names) model_container.add_nodes([slice_node]) + return output_names + + +class Split(OpConverter): + """ Operator converter for Split. + """ + + @classmethod + def convert_attributes(cls, attrs): + indices_or_sections = attrs['indices_or_sections'] + + if isinstance(indices_or_sections, tvm.ir.container.Array)\ + or isinstance(indices_or_sections, list): + indices_or_sections = attrs.get_int_tuple('indices_or_sections') + elif isinstance(indices_or_sections, tvm.ir.PrimExpr): + indices_or_sections = indices_or_sections.value + + return { + 'indices_or_section': indices_or_sections, + 'axis': attrs.get_int('axis'), + } + + @classmethod + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + shape = input_node['types'][0].concrete_shape + + indices = attrs["indices_or_section"] + axis = attrs["axis"] + axis_len = shape[axis] + + if isinstance(indices, int): + split = [axis_len // indices] * (indices) + else: + split = [None] * (len(indices) + 1) + split[0] = indices[0] + for i in range(len(indices) - 1): + split[i + 1] = indices[i + 1] - indices[i] + split[-1] = axis_len - indices[-1] + + output_names = ["{}_{}".format(node_entry['name'], i) for i in range(len(split))] + slice_node = onnx.helper.make_node(cls.__name__, + node_entry['input_names'], + output_names, + split=split, + axis=axis) + model_container.add_nodes([slice_node]) + + return output_names class ConstantOfShapeZeros(OpConverter): @@ -438,12 +541,15 @@ def convert_attributes(cls, attrs): } @classmethod - def convert(cls, node, model_container, node_list): - attrs = cls.convert_attributes(node['node'].attrs) - input_node = node_list[node['inputs'][0][0]] - dtype = input_node['node'].type_annotation.dtype - input_shape_name = 'shape_{}'.format(node['output_names'][0]) - shape = 
[val.value for val in input_node['node'].type_annotation.shape] + def convert(cls, node_entry, model_container, node_dict): + attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + input_node = node_dict[node_entry['inputs'][0]] + assert len(input_node) == 1, "input node can not be a Tuple" + input_node = input_node[0] + dtype = input_node['relay_node'].type_annotation.dtype + output_names = [node_entry['name']] + input_shape_name = 'shape_{}'.format(output_names[0]) + shape = [val.value for val in input_node['relay_node'].type_annotation.shape] shape = numpy.asarray(shape).astype(numpy.int64) add_input(shape, input_shape_name, model_container) @@ -453,9 +559,10 @@ def convert(cls, node, model_container, node_list): node = onnx.helper.make_node('ConstantOfShape', [input_shape_name], - node['output_names'], + output_names, value=tensor_value) model_container.add_nodes([node]) + return output_names class ConstantOfShapeOnes(ConstantOfShapeZeros): @@ -496,7 +603,8 @@ def convert_attributes(cls, attrs): 'equal': rename('Equal'), 'zeros_like': ConstantOfShapeZeros, 'ones_like': ConstantOfShapeOnes, - 'subtract': rename('Sub') + 'subtract': rename('Sub'), + 'split': Split } @@ -543,12 +651,12 @@ def make_model(self): kwargs = {} kwargs["opset_imports"] = self._get_opsets() kwargs["producer_name"] = 'TVM Relay' - kwargs["producer_name"] = tvm.__version__ + kwargs["producer_version"] = tvm.__version__ return onnx.helper.make_model(onnx_graph, **kwargs) -class RelayToONNXConverter(object): +class RelayToONNXConverter(ExprVisitor): """A helper class converting topologically sorted Relay nodes to ONNX model Parameters @@ -556,44 +664,100 @@ class RelayToONNXConverter(object): name : str name of the model - node_list : list - topologically sorted Relay Node entry list + params : dict + dict of the parameter names and NDarray values + + opset_version : int + target onnx opset version + """ - def __init__(self, name, node_list, params, opset_version): + def __init__(self, name, params, opset_version): + super().__init__() self._name = {} self._mc = ModelContainer(name, opset_version) - self._node_list = node_list self._params = params + self._node_dict = {} + self._node_count = 0 + self.last_node = None - def convert_to_onnx(self): + @classmethod + def _get_node_entry(cls, relay_node, name, node_index): + return {"relay_node": relay_node, + "inputs": [relay_node], # inputs in the form of relay nodes + "types": [], # output types in case of call nodes else self type + "name": name, # name of the node + "input_names": [name], # input names in case of call nodes else self name + "output_names": [name], # output names in case of call nodes else self name + "op": None, # op name in case of call node else None + "index": node_index + } + + def convert_to_onnx(self, func): """ Loop through topologically sorted list of Relay nodes and generate a ONNX model""" - for idx, node_entry in enumerate(self._node_list): - out_idx = idx - node = node_entry['node'] - if isinstance(node, Call): - self._add_node(node_entry, idx) - elif isinstance(node, Var): - self._add_input(node_entry, idx) - elif isinstance(node, Constant): - self._add_constant_input(node_entry, idx) - elif isinstance(node, (TupleGetItem, Tuple)): - out_idx = idx - 1 # TODO: Need to work on this. 
- # No equivalent ONNX operator found yet - else: - raise NotImplementedError("Relay Node of type {0} is not " - "implemented yet".format(type(node))) - - if idx == len(self._node_list) - 1: - self._add_output(self._node_list[out_idx], out_idx) + self.visit(func) + + self._add_output(self._node_dict[self.last_node]) model = self._mc.make_model() polished_model = onnx.utils.polish_model(model) return polished_model - def _tuple_to_name(self, input): - """convert tuple of node indexes to string""" - return 'node_{0}'.format(input[0]) + def visit(self, expr): + self._node_count += 1 + super().visit(expr) + + def visit_constant(self, const): + node_index = self._node_count + name = "Constant_" + str(node_index) + node_entry = self._get_node_entry(const, name, node_index) + node_entry["types"] = [const.checked_type] + + self._add_constant_input(node_entry, node_index) + self._node_dict[const] = [node_entry] + + def visit_var(self, var): + node_index = self._node_count + node_entry = self._get_node_entry(var, var.name_hint, node_index) + node_entry["types"] = [var.type_annotation] + + self._add_input(node_entry, node_index) + self._node_dict[var] = [node_entry] + + def visit_tuple(self, tup): + self._node_dict[tup] = [] + for f in tup.fields: + self.visit(f) + self._node_dict[tup].extend(self._node_dict[f]) + + def visit_tuple_getitem(self, tup): + self.visit(tup.tuple_value) + self._node_dict[tup] = [self._node_dict[tup.tuple_value][tup.index]] + self.last_node = tup + + def visit_call(self, call_node): + node_index = self._node_count + op = call_node.op + name = "{}_{}".format(op, node_index) + node_entry = self._get_node_entry(call_node, name, node_index) + + node_entry["op"] = op + node_entry["input_names"] = [] + node_entry["inputs"] = [] + node_entry["output_names"] = None + for input_arg in call_node.args: + self.visit(input_arg) + input_names = [] + for arg_node_entry in self._node_dict[input_arg]: + input_names.extend(arg_node_entry["output_names"]) + node_entry["input_names"].extend(input_names) + node_entry["inputs"].extend([input_arg]) + + node_entry['types'] = call_node_infer_type(call_node) + self.last_node = call_node + output_names = self._add_node(node_entry, node_index) + node_entry["output_names"] = output_names if output_names is not None else [name] + self._node_dict[call_node] = [node_entry] def _add_node(self, node_entry, idx): """Convert Relay operator node to ONNX operator and add it to container nodes list""" @@ -602,15 +766,8 @@ def _add_node(self, node_entry, idx): "not supported.".format(node_entry['op'].name)) converter = relay_to_onnx_op_mapping[node_entry['op'].name]() - node_entry['output_names'] = [self._tuple_to_name([idx, 0, 0])] - node_entry['input_names'] = [] - for input_idx_tuple in node_entry['inputs']: - if self._node_list[input_idx_tuple[0]]['name']: - node_entry['input_names'].append(self._node_list[input_idx_tuple[0]]['name']) - else: - node_entry['input_names'].append(self._tuple_to_name(input_idx_tuple)) - converter.convert(node_entry, self._mc, self._node_list) + return converter.convert(node_entry, self._mc, self._node_dict) def _add_params(self, node_entry, idx): """Add param value to initializer and name to inputs""" @@ -631,9 +788,7 @@ def _add_constant_input(self, node_entry, idx): """Create named input for constant and add it to container inputs. 
If input is a parameter then add to param """ - node = node_entry['node'] - if not node_entry['name']: - node_entry['name'] = self._tuple_to_name([idx, 0, 0]) + node = node_entry['relay_node'] param_name = node_entry['name'] self._params[param_name] = node.data self._add_params(node_entry, idx) @@ -650,15 +805,16 @@ def _add_input(self, node_entry, idx): shape=type.concrete_shape) self._mc.add_inputs([input]) - def _add_output(self, node_entry, idx): + def _add_output(self, node_entries): """Add output node to container outputs.""" - type = node_entry['types'][0] - dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(type.dtype)] - output = onnx.helper.make_tensor_value_info(self._tuple_to_name([idx, 0, 0]), - dtype, - shape=type.concrete_shape) - self._mc.add_outputs([output]) + for node_entry in node_entries: + for type, output_name in zip(node_entry['types'], node_entry['output_names']): + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(type.dtype)] + output = onnx.helper.make_tensor_value_info(output_name, + dtype, + shape=type.concrete_shape) + self._mc.add_outputs([output]) def to_onnx(relay_ir, params, name, opset_version=11, path=None): @@ -675,6 +831,9 @@ def to_onnx(relay_ir, params, name, opset_version=11, path=None): name : str name of the output ONNX graph + opset_version : int + target onnx opset version + path : str The path where ONNX model will be saved @@ -693,12 +852,9 @@ def to_onnx(relay_ir, params, name, opset_version=11, path=None): "version {}. Upgrade the ONNX package to latest version.".format( get_onnx_version(), opset_version)) - node_list = [] # ONNX needs a topologically sorted list of nodes - node_dict = {} func = relay_ir["main"] if isinstance(relay_ir, tvm.ir.IRModule) else relay_ir - _expr2graph_impl(func, [], node_dict, node_list) - converter = RelayToONNXConverter(name, node_list, params, opset_version) - onnx_model = converter.convert_to_onnx() + converter = RelayToONNXConverter(name, params, opset_version) + onnx_model = converter.convert_to_onnx(func) if path: onnx.save(onnx_model, path) diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index e6550f10231f..59f05a70bd00 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -21,10 +21,12 @@ pytest.importorskip('onnxruntime') import numpy as np +import onnxruntime as rt + import tvm from tvm import relay from tvm.contrib.target.onnx import to_onnx -import onnxruntime as rt + def func_to_onnx(func, name): @@ -39,9 +41,9 @@ def run_onnx(onnx_model, input_data): input_names = {} for input, data in zip(sess.get_inputs(), input_data): input_names[input.name] = data - output_name = sess.get_outputs()[0].name - res = sess.run([output_name], input_names) - return res[0] + output_names = [out.name for out in sess.get_outputs()] + res = sess.run(output_names, input_names) + return res def run_relay(func, data_tuple): @@ -49,13 +51,21 @@ def run_relay(func, data_tuple): ctx = tvm.context('llvm', 0) intrp = relay.create_executor("graph", ctx=ctx, target=target) relay_res = intrp.evaluate(func)(*data_tuple) - return relay_res.asnumpy() + + result = [] + relay_res = relay_res if isinstance(relay_res, list) else [relay_res] + for res in relay_res: + result.append(res.asnumpy()) + + return result def verify_results(relay_func, indata, test_name, rtol=1e-7, atol=0): - relay_res = run_relay(relay_func, indata) - onnx_res = run_onnx(func_to_onnx(relay_func, test_name), indata) - np.testing.assert_allclose(relay_res, onnx_res, rtol=rtol, 
atol=atol) + relay_results = run_relay(relay_func, indata) + onnx_results = run_onnx(func_to_onnx(relay_func, test_name), indata) + + for relay_res, onnx_res in zip(relay_results, onnx_results): + np.testing.assert_allclose(relay_res, onnx_res, rtol=rtol, atol=atol) def test_add(): @@ -297,6 +307,45 @@ def verify_mean(data_shape, axis, exclude, keepdims): verify_mean((3, 2, 1), 1, False, True) +def test_split(): + def verify_split(dshape, indices_or_sections, axis=None): + dtype = "float32" + x = relay.var("x", relay.ty.TensorType(dshape, "float32")) + y = relay.split(x, indices_or_sections, axis=axis) + func = relay.Function([x], y.astuple()) + x_data = np.random.uniform(size=dshape).astype(dtype) + + verify_results(func, [x_data], 'test_split', rtol=1e-5, atol=1e-5) + + verify_split((5, 5, 2, 2), 5, axis=1) + verify_split((5, 5, 2, 2), 5, axis=0) + verify_split((5, 5, 2, 2), [1, 3, 4], axis=0) + verify_split((5, 5, 2, 2), [1, 3, 4], axis=1) + + +def test_concatenate(): + def verify_concatenate(shapes, axis, dtype="float32"): + in_vars = [] + in_data = [] + for i, shape in enumerate(shapes): + in_vars.append(relay.var("x"+ str(i), relay.ty.TensorType(shape, dtype))) + in_data.append(np.random.uniform(size=shape).astype(dtype)) + + out_tensor = relay.concatenate(in_vars, axis) + func = relay.Function(in_vars, out_tensor) + verify_results(func, in_data, 'test_split', rtol=1e-5, atol=1e-5) + + verify_concatenate([(2,), (2,), (2,)], -1) + verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1) + verify_concatenate([(1, 2, 4), (1, 2, 3), (1, 2, 7), (1, 2, 8), (1, 2, 1)], -1) + verify_concatenate([(5, 6, 7, 3), + (16, 6, 7, 3), + (12, 6, 7, 3), + (8, 6, 7, 3), + (2, 6, 7, 3)], 0) + verify_concatenate([(1, 14400), (1, 2400), (1, 640), (1, 240)], 1) + + def test_strided_slice(): def verify_strided_slice(dshape, begin, end, strides): x = relay.var("x", relay.TensorType(dshape, "float32")) @@ -315,12 +364,7 @@ def verify_strided_slice(dshape, begin, end, strides): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None) verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None) verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None) - - # TODO - test cases below fails for TVM itself error -strided_slice get empty slice at axis 1 - # verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], [1, -1, 1]) - # verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2]) - # verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], [1, -1, 1]) - # verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2]) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2]) def test_cmp_type(): @@ -382,7 +426,6 @@ def check_binary_op(opfunc, dtype): test_dense() test_max_pool() test_batch_flatten() - test_bias_add() test_batch_norm() test_pad() test_mean() From 6daf93faa69223b81391ba734d74113a4c6511cb Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 6 Jun 2020 18:55:26 +0530 Subject: [PATCH 17/28] onnx fixes --- python/tvm/contrib/target/onnx.py | 154 ++++++++++++++---------------- tests/python/contrib/test_onnx.py | 35 ++++++- 2 files changed, 107 insertions(+), 82 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 88d29b5bf8dc..afa589c7e7d7 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -19,6 +19,7 @@ import os import struct +import copy import numpy import onnx import onnx.utils @@ -29,7 +30,6 @@ from tvm.relay.expr_functor import ExprVisitor from tvm.relay.ty import TupleType, TensorType - 
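(Aside, not part of the patch: the converter introduced in PATCH 16 and refined here subclasses tvm.relay.expr_functor.ExprVisitor, imported above. A minimal sketch of that traversal pattern, an illustrative visitor that tallies operator occurrences:)

    from tvm.relay.expr_functor import ExprVisitor

    class OpCounter(ExprVisitor):
        """Toy visitor: count how often each operator occurs in an expression."""

        def __init__(self):
            super().__init__()
            self.counts = {}

        def visit_call(self, call):
            self.counts[str(call.op)] = self.counts.get(str(call.op), 0) + 1
            super().visit_call(call)  # keep recursing into the call's op and args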
ONNX_OPSET_VERSONS_SUPPORTED = [11] @@ -50,6 +50,7 @@ def infer_type(node): def call_node_infer_type(node): + """infer the output types of call node""" infer_out = infer_type(node) out_type = infer_out._checked_type_ types = [] @@ -89,13 +90,11 @@ def convert_attributes(cls, attrs): @classmethod def convert(cls, node_entry, model_container, node_dict): attrs = cls.convert_attributes(node_entry['relay_node'].attrs) - output_names = [node_entry['name']] onnx_node = onnx.helper.make_node(cls.__name__, - node_entry['input_names'], - output_names, - **attrs) + node_entry['input_names'], + node_entry['output_names'], + **attrs) model_container.add_nodes([onnx_node]) - return output_names def rename(op_name): @@ -117,13 +116,12 @@ def convert(cls, node_entry, model_container, node_dict): shape = numpy.asarray([a.value for a in node_entry['relay_node'].attrs.newshape], dtype=numpy.int64) - output_name = node_entry['name'] - input_name = 'shape{}'.format(output_name) + input_name = 'shape{}'.format(node_entry['name']) node = onnx.helper.make_node(cls.__name__, [node_entry['input_names'][0], input_name], - [output_name]) + node_entry['output_names']) model_container.add_nodes([node]) add_input(shape, input_name, model_container) - return [output_name] + class Conv(OpConverter): """ Operator converter for Conv. @@ -168,8 +166,7 @@ class MatMul(OpConverter): @classmethod def convert(cls, node_entry, model_container, node_dict): - output_name = node_entry['name'] - inter_output_name = 'inter{}'.format(output_name) + inter_output_name = 'inter{}'.format(node_entry['name']) transpose_node = onnx.helper.make_node(Transpose.__name__, [node_entry['input_names'][1]], [inter_output_name], @@ -177,9 +174,8 @@ def convert(cls, node_entry, model_container, node_dict): model_container.add_nodes([transpose_node]) inputs = [node_entry['input_names'][0], inter_output_name] - matmul_node = onnx.helper.make_node(cls.__name__, inputs, [output_name]) + matmul_node = onnx.helper.make_node(cls.__name__, inputs, node_entry['output_names']) model_container.add_nodes([matmul_node]) - return [output_name] class Flatten(OpConverter): @@ -211,33 +207,31 @@ def convert(cls, node_entry, model_container, node_dict): """ attrs = cls.convert_attributes(node_entry['relay_node'].attrs) transpose_out_name = node_entry['input_names'][0] - output_names = [node_entry['name']] - + inter_output_names = [node_entry['output_names'][0]] # axis==3 means channel is specified along the 3rd axis if attrs['axis'] == 3: - transpose_out_name = 'transpose_{}'.format(output_names[0]) + transpose_out_name = 'transpose_{}'.format(node_entry['name']) node_transposed = onnx.helper.make_node(Transpose.__name__, [node_entry['input_names'][0]], [transpose_out_name], perm=[0, 3, 1, 2]) model_container.add_nodes([node_transposed]) - output_names = ['batch_norm_{}'.format(output_names[0])] + inter_output_names = ['batch_norm_{}'.format(node_entry['name'])] + input_names = [transpose_out_name] + node_entry['input_names'][1:] batch_norm_node = onnx.helper.make_node(cls.__name__, - [transpose_out_name] + node_entry['input_names'][1:], - output_names, + input_names, + inter_output_names, epsilon=attrs['epsilon']) model_container.add_nodes([batch_norm_node]) - final_out_names = output_names + if attrs['axis'] == 3: node_transposed = onnx.helper.make_node(Transpose.__name__, - output_names, - [node_entry['name']], + inter_output_names, + [node_entry['output_names'][0]], perm=[0, 2, 3, 1]) model_container.add_nodes([node_transposed]) - final_out_names = 
[node_entry['name']] - return final_out_names class Dropout(OpConverter): """ Operator converter for Dropout. @@ -291,10 +285,8 @@ def convert(cls, node_entry, model_container, node_dict): inter_output_name = node_entry['input_names'][1] inputs = [node_entry['input_names'][0], inter_output_name] - output_names = [node_entry['name']] - matmul_node = onnx.helper.make_node('Add', inputs, output_names) + matmul_node = onnx.helper.make_node('Add', inputs, node_entry['output_names']) model_container.add_nodes([matmul_node]) - return output_names class ReduceMean(OpConverter): @@ -322,14 +314,12 @@ def convert(cls, node_entry, model_container, node_dict): all_axis = list(range(len(shape))) axis = set(all_axis) - set(axis) - output_names = [node_entry['name']] node = onnx.helper.make_node(cls.__name__, node_entry['input_names'], - output_names, + node_entry['output_names'], axes=axis, keepdims=keepdims) model_container.add_nodes([node]) - return output_names class Pad(OpConverter): @@ -359,18 +349,17 @@ def convert(cls, node_entry, model_container, node_dict): """ attrs = cls.convert_attributes(node_entry['relay_node'].attrs) - output_names = [node_entry['name']] + name = node_entry['name'] data = numpy.asarray(attrs['pads'], dtype=attrs['pads'][0].dtype).astype(numpy.int64) - input_name = 'pads_{}'.format(output_names[0]) + input_name = 'pads_{}'.format(name) value = numpy.dtype(node_entry['types'][0].dtype).type(attrs['constant_value']) - input_value_name = 'value_{}'.format(output_names[0]) + input_value_name = 'value_{}'.format(name) add_input(data, input_name, model_container) add_input(value, input_value_name, model_container) input_names = [node_entry['input_names'][0], input_name, input_value_name] - node = onnx.helper.make_node(cls.__name__, input_names, output_names) + node = onnx.helper.make_node(cls.__name__, input_names, node_entry['output_names']) model_container.add_nodes([node]) - return output_names class Softmax(OpConverter): @@ -409,13 +398,11 @@ def convert(cls, node_entry, model_container, node_dict): else: axis = node_entry['relay_node'].attrs.get_int_tuple("axis") - output_names = [node_entry['name']] node = onnx.helper.make_node(cls.__name__, node_entry['input_names'], - output_names, + node_entry['output_names'], axes=axis) model_container.add_nodes([node]) - return output_names class Slice(OpConverter): @@ -445,13 +432,13 @@ def convert(cls, node_entry, model_container, node_dict): for i in range(len(ends), len(shape)): ends.append(shape[i] + 1) - output_names = [node_entry['name']] + name = node_entry['name'] starts = numpy.asarray(starts).astype(numpy.int64) - starts_name = 'starts_{}'.format(output_names[0]) + starts_name = 'starts_{}'.format(name) add_input(starts, starts_name, model_container) ends = numpy.asarray(ends).astype(numpy.int64) - ends_name = 'ends_{}'.format(output_names[0]) + ends_name = 'ends_{}'.format(name) add_input(ends, ends_name, model_container) input_names = node_entry['input_names'] + [starts_name, ends_name] @@ -462,20 +449,19 @@ def convert(cls, node_entry, model_container, node_dict): assert len(axes) == len(attrs['steps']), "axes and steps should be of same size" steps = numpy.asarray(attrs['steps']).astype(numpy.int64) - steps_name = 'steps_{}'.format(output_names[0]) + steps_name = 'steps_{}'.format(name) add_input(steps, steps_name, model_container) axes = numpy.asarray(attrs['axes']).astype(numpy.int64) - axes_name = 'axes_{}'.format(output_names[0]) + axes_name = 'axes_{}'.format(name) add_input(axes, axes_name, model_container) 
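The Slice conversion above leans on the same helper pattern as Pad: from opset 11 onward these parameters are tensor inputs rather than node attributes, so the converter materializes each one via add_input, which registers a value info and an initializer under the same name. A standalone sketch of that pattern, with illustrative names ('x', 'y', 'pads_0' are not taken from the converter):

    import numpy
    import onnx
    from onnx import numpy_helper

    # Rank-2 input: pad axis 1 by one element at both ends
    # (opset-11 layout: [x1_begin, x2_begin, x1_end, x2_end]).
    pads = numpy.asarray([0, 1, 0, 1], dtype=numpy.int64)
    pads_tensor = numpy_helper.from_array(pads, 'pads_0')  # becomes a graph initializer
    pad_node = onnx.helper.make_node('Pad', ['x', 'pads_0'], ['y'], mode='constant')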
input_names = input_names + [axes_name, steps_name] slice_node = onnx.helper.make_node(cls.__name__, input_names, - output_names) + node_entry['output_names']) model_container.add_nodes([slice_node]) - return output_names class Split(OpConverter): @@ -486,10 +472,9 @@ class Split(OpConverter): def convert_attributes(cls, attrs): indices_or_sections = attrs['indices_or_sections'] - if isinstance(indices_or_sections, tvm.ir.container.Array)\ - or isinstance(indices_or_sections, list): + if isinstance(indices_or_sections, (list, tvm.ir.container.Array)): indices_or_sections = attrs.get_int_tuple('indices_or_sections') - elif isinstance(indices_or_sections, tvm.ir.PrimExpr): + if isinstance(indices_or_sections, tvm.ir.PrimExpr): indices_or_sections = indices_or_sections.value return { @@ -506,29 +491,29 @@ def convert(cls, node_entry, model_container, node_dict): input_node = input_node[0] shape = input_node['types'][0].concrete_shape - indices = attrs["indices_or_section"] + indices_or_sect = attrs["indices_or_section"] axis = attrs["axis"] - axis_len = shape[axis] + axis_length = shape[axis] - if isinstance(indices, int): - split = [axis_len // indices] * (indices) + if isinstance(indices_or_sect, int): + split = [axis_length // indices_or_sect] * indices_or_sect else: - split = [None] * (len(indices) + 1) - split[0] = indices[0] - for i in range(len(indices) - 1): - split[i + 1] = indices[i + 1] - indices[i] - split[-1] = axis_len - indices[-1] + split = [] + for i in range(len(indices_or_sect) + 1): + if i == 0: + split.append(indices_or_sect[0]) + elif i == len(indices_or_sect): + split.append(axis_length - indices_or_sect[-1]) + else: + split.append(indices_or_sect[i] - indices_or_sect[i - 1]) - output_names = ["{}_{}".format(node_entry['name'], i) for i in range(len(split))] slice_node = onnx.helper.make_node(cls.__name__, node_entry['input_names'], - output_names, + node_entry['output_names'], split=split, axis=axis) model_container.add_nodes([slice_node]) - return output_names - class ConstantOfShapeZeros(OpConverter): """ Operator converter for ConstantOfShape. 
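The loop in the Split converter above turns Relay's cut indices into the section sizes that ONNX Split expects. The same arithmetic in isolation, with a hypothetical axis length, makes the mapping easier to check:

    # Relay gives cut points along an axis; ONNX Split wants section sizes.
    # E.g. indices [1, 3, 4] on an axis of length 5 -> sizes [1, 2, 1, 1].
    def indices_to_sizes(indices, axis_length):
        cuts = [0] + list(indices) + [axis_length]
        return [cuts[i + 1] - cuts[i] for i in range(len(cuts) - 1)]

    assert indices_to_sizes([1, 3, 4], 5) == [1, 2, 1, 1]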
@@ -547,8 +532,7 @@ def convert(cls, node_entry, model_container, node_dict): assert len(input_node) == 1, "input node can not be a Tuple" input_node = input_node[0] dtype = input_node['relay_node'].type_annotation.dtype - output_names = [node_entry['name']] - input_shape_name = 'shape_{}'.format(output_names[0]) + input_shape_name = 'shape_{}'.format(node_entry['name']) shape = [val.value for val in input_node['relay_node'].type_annotation.shape] shape = numpy.asarray(shape).astype(numpy.int64) add_input(shape, input_shape_name, model_container) @@ -559,10 +543,9 @@ def convert(cls, node_entry, model_container, node_dict): node = onnx.helper.make_node('ConstantOfShape', [input_shape_name], - output_names, + node_entry['output_names'], value=tensor_value) model_container.add_nodes([node]) - return output_names class ConstantOfShapeOnes(ConstantOfShapeZeros): @@ -730,22 +713,31 @@ def visit_tuple(self, tup): self.visit(f) self._node_dict[tup].extend(self._node_dict[f]) - def visit_tuple_getitem(self, tup): - self.visit(tup.tuple_value) - self._node_dict[tup] = [self._node_dict[tup.tuple_value][tup.index]] self.last_node = tup - def visit_call(self, call_node): + def visit_tuple_getitem(self, t): + self.visit(t.tuple_value) + tup_node = self._node_dict[t.tuple_value] + if len(tup_node) > 1: + self._node_dict[t] = tup_node[t.index] + else: + node_entry = copy.deepcopy(tup_node[0]) + output_names = [node_entry["output_names"][t.index]] + node_entry["output_names"] = output_names + self._node_dict[t] = [node_entry] + self.last_node = t + + def visit_call(self, call): node_index = self._node_count - op = call_node.op + op = call.op name = "{}_{}".format(op, node_index) - node_entry = self._get_node_entry(call_node, name, node_index) + node_entry = self._get_node_entry(call, name, node_index) node_entry["op"] = op node_entry["input_names"] = [] node_entry["inputs"] = [] node_entry["output_names"] = None - for input_arg in call_node.args: + for input_arg in call.args: self.visit(input_arg) input_names = [] for arg_node_entry in self._node_dict[input_arg]: @@ -753,11 +745,13 @@ def visit_call(self, call_node): node_entry["input_names"].extend(input_names) node_entry["inputs"].extend([input_arg]) - node_entry['types'] = call_node_infer_type(call_node) - self.last_node = call_node - output_names = self._add_node(node_entry, node_index) - node_entry["output_names"] = output_names if output_names is not None else [name] - self._node_dict[call_node] = [node_entry] + node_entry['types'] = call_node_infer_type(call) + node_entry["output_names"] = [] + for i in range(len(node_entry['types'])): + node_entry["output_names"].append(name + str(i)) + self.last_node = call + self._add_node(node_entry, node_index) + self._node_dict[call] = [node_entry] def _add_node(self, node_entry, idx): """Convert Relay operator node to ONNX operator and add it to container nodes list""" @@ -850,7 +844,7 @@ def to_onnx(relay_ir, params, name, opset_version=11, path=None): if opset_version > defs.onnx_opset_version(): raise Exception("The ONNX package installed of version {} does not support the opset " "version {}. 
Upgrade the ONNX package to latest version.".format( - get_onnx_version(), opset_version)) + get_onnx_version(), opset_version)) func = relay_ir["main"] if isinstance(relay_ir, tvm.ir.IRModule) else relay_ir converter = RelayToONNXConverter(name, params, opset_version) @@ -900,7 +894,7 @@ def save_to_file(hex_str, path=None, fmt="onnx"): while offset < len(onnx_ir): stop = offset + 4 (name_size,) = struct.unpack('I', onnx_ir[offset:stop]) - name = onnx_ir[stop : stop + name_size].decode("utf-8") + name = onnx_ir[stop:stop + name_size].decode("utf-8") stop = stop + name_size (model_size,) = struct.unpack('I', onnx_ir[stop:stop + 4]) stop = stop + 4 diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index 59f05a70bd00..c24104a6f363 100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -328,12 +328,12 @@ def verify_concatenate(shapes, axis, dtype="float32"): in_vars = [] in_data = [] for i, shape in enumerate(shapes): - in_vars.append(relay.var("x"+ str(i), relay.ty.TensorType(shape, dtype))) + in_vars.append(relay.var("x" + str(i), relay.ty.TensorType(shape, dtype))) in_data.append(np.random.uniform(size=shape).astype(dtype)) out_tensor = relay.concatenate(in_vars, axis) func = relay.Function(in_vars, out_tensor) - verify_results(func, in_data, 'test_split', rtol=1e-5, atol=1e-5) + verify_results(func, in_data, 'test_concatenate', rtol=1e-5, atol=1e-5) verify_concatenate([(2,), (2,), (2,)], -1) verify_concatenate([(2, 3, 4), (2, 2, 4), (2, 5, 4)], 1) @@ -417,6 +417,34 @@ def check_binary_op(opfunc, dtype): for dtype in ['float32']: check_binary_op(opfunc, dtype) + +def test_tuple_types(): + def verify_tuple_types(dshape, indices_or_sections, axis=None, dtype = "float32"): + x = relay.var("x", relay.ty.TensorType(dshape, dtype)) + y = relay.split(x, indices_or_sections, axis=axis) + z = relay.concatenate(y, axis=axis) + func = relay.Function([x], z) + x_data = np.random.uniform(size=dshape).astype(dtype) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + split_z = relay.split(z, indices_or_sections, axis=axis) + func = relay.Function([x], split_z.astuple()) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + out = relay.Tuple([y[0] + y[1], y[0] - y[1]]) + func = relay.Function([x], out) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + z = relay.concatenate(out, axis=axis) + func = relay.Function([x], z) + verify_results(func, [x_data], 'test_tuple_types', rtol=1e-5, atol=1e-5) + + verify_tuple_types((5, 5, 2, 2), 5, axis=1) + verify_tuple_types((5, 5, 2, 2), 5, axis=0) + verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=0) + verify_tuple_types((5, 5, 2, 2), [1, 3, 4], axis=1) + + if __name__ == '__main__': test_add() test_bias_add() @@ -429,8 +457,11 @@ def check_binary_op(opfunc, dtype): test_batch_norm() test_pad() test_mean() + test_split() + test_concatenate() test_sofmax() test_squeeze() test_strided_slice() test_cmp_type() test_binary_op() + test_tuple_types() From 00060d69a9b937d7fb33c87f98653ef556c93650 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 6 Jun 2020 19:02:38 +0530 Subject: [PATCH 18/28] lint fixes --- python/tvm/contrib/target/onnx.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index afa589c7e7d7..073cfb17b1be 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -844,7 +844,7 @@ def 
to_onnx(relay_ir, params, name, opset_version=11, path=None): if opset_version > defs.onnx_opset_version(): raise Exception("The ONNX package installed of version {} does not support the opset " "version {}. Upgrade the ONNX package to latest version.".format( - get_onnx_version(), opset_version)) + get_onnx_version(), opset_version)) func = relay_ir["main"] if isinstance(relay_ir, tvm.ir.IRModule) else relay_ir converter = RelayToONNXConverter(name, params, opset_version) From 1dd1c8b132b882c0e2342d93334f21305f84d1e4 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 6 Jun 2020 19:07:32 +0530 Subject: [PATCH 19/28] doc string changes --- python/tvm/contrib/target/onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 073cfb17b1be..e18cec875feb 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -640,7 +640,7 @@ def make_model(self): class RelayToONNXConverter(ExprVisitor): - """A helper class converting topologically sorted Relay nodes to ONNX model + """A helper class to traverse the Relay graph and convert Relay nodes to ONNX model Parameters ---------- @@ -677,7 +677,7 @@ def _get_node_entry(cls, relay_node, name, node_index): } def convert_to_onnx(self, func): - """ Loop through topologically sorted list of Relay nodes and generate a ONNX model""" + """ Traverse Relay graph and generate a ONNX model""" self.visit(func) From ec9a39537227d0d6f9075e7fc3841da6d9f58e38 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Thu, 11 Jun 2020 14:00:23 +0530 Subject: [PATCH 20/28] review comments --- python/tvm/contrib/target/onnx.py | 70 ++++++++++++++----------------- tests/python/contrib/test_onnx.py | 32 +++++++------- 2 files changed, 50 insertions(+), 52 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index e18cec875feb..5c334cb76b3b 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -53,12 +53,10 @@ def call_node_infer_type(node): """infer the output types of call node""" infer_out = infer_type(node) out_type = infer_out._checked_type_ - types = [] if isinstance(out_type, TensorType): - types.append(out_type) + types = [out_type] elif isinstance(out_type, TupleType): - for tupe_type in out_type.fields: - types.append(tupe_type) + types = list(out_type.fields) else: raise RuntimeError("Unsupported output type %s in operator %s" % (type(out_type), node.op.nae)) @@ -414,49 +412,47 @@ def convert_attributes(cls, attrs): return { 'starts': attrs.get_int_tuple('begin'), 'ends': attrs.get_int_tuple('end'), - 'steps': attrs.get_int_tuple('strides') + 'steps': attrs.get_int_tuple('strides'), + 'slice_mode': attrs.get_str('slice_mode') } @classmethod def convert(cls, node_entry, model_container, node_dict): attrs = cls.convert_attributes(node_entry['relay_node'].attrs) + name = node_entry['name'] input_node = node_dict[node_entry['inputs'][0]] assert len(input_node) == 1, "input node can not be a Tuple" input_node = input_node[0] shape = input_node['types'][0].shape + starts = list(attrs['starts']) ends = list(attrs['ends']) - for i in range(len(starts), len(shape)): - starts.append(0) - for i in range(len(ends), len(shape)): - ends.append(shape[i] + 1) - - name = node_entry['name'] - starts = numpy.asarray(starts).astype(numpy.int64) - starts_name = 'starts_{}'.format(name) - add_input(starts, starts_name, model_container) - - ends = numpy.asarray(ends).astype(numpy.int64) - 
ends_name = 'ends_{}'.format(name) - add_input(ends, ends_name, model_container) - - input_names = node_entry['input_names'] + [starts_name, ends_name] - - if attrs['steps']: - axes = list(range(len(shape))) - attrs['axes'] = axes - assert len(axes) == len(attrs['steps']), "axes and steps should be of same size" + steps = list(attrs['steps']) + starts += [0] * (len(shape) - len(starts)) + ends += [shape[i] + 1 for i in range(len(ends), len(shape))] + axes = list(range(len(shape))) + + if attrs['slice_mode'] == 'size': + ends = [starts[i] + (shape[i] + 1 if ends[i] < 0 else ends[i]) + for i in range(len(shape))] + steps = [1] * len(shape) + else: + steps += [1] * (len(shape) - len(steps)) - steps = numpy.asarray(attrs['steps']).astype(numpy.int64) - steps_name = 'steps_{}'.format(name) - add_input(steps, steps_name, model_container) + def _add_input(val, input_name): + val_arr = numpy.asarray(val).astype(numpy.int64) + input_name = '{}_{}'.format(name, input_name) + add_input(val_arr, input_name, model_container) + return input_name - axes = numpy.asarray(attrs['axes']).astype(numpy.int64) - axes_name = 'axes_{}'.format(name) - add_input(axes, axes_name, model_container) + input_names = [] + input_names.append(_add_input(starts, 'starts')) + input_names.append(_add_input(ends, 'ends')) + input_names.append(_add_input(axes, 'axes')) + input_names.append(_add_input(steps, 'steps')) - input_names = input_names + [axes_name, steps_name] + input_names = [node_entry['input_names'][0]] + input_names slice_node = onnx.helper.make_node(cls.__name__, input_names, @@ -665,7 +661,7 @@ def __init__(self, name, params, opset_version): self.last_node = None @classmethod - def _get_node_entry(cls, relay_node, name, node_index): + def _get_node_entry(cls, relay_node, name): return {"relay_node": relay_node, "inputs": [relay_node], # inputs in the form of relay nodes "types": [], # output types in case of call nodes else self type @@ -673,14 +669,12 @@ def _get_node_entry(cls, relay_node, name, node_index): "input_names": [name], # input names in case of call nodes else self name "output_names": [name], # output names in case of call nodes else self name "op": None, # op name in case of call node else None - "index": node_index } def convert_to_onnx(self, func): """ Traverse Relay graph and generate a ONNX model""" self.visit(func) - self._add_output(self._node_dict[self.last_node]) model = self._mc.make_model() polished_model = onnx.utils.polish_model(model) @@ -693,7 +687,7 @@ def visit(self, expr): def visit_constant(self, const): node_index = self._node_count name = "Constant_" + str(node_index) - node_entry = self._get_node_entry(const, name, node_index) + node_entry = self._get_node_entry(const, name) node_entry["types"] = [const.checked_type] self._add_constant_input(node_entry, node_index) @@ -701,7 +695,7 @@ def visit_constant(self, const): def visit_var(self, var): node_index = self._node_count - node_entry = self._get_node_entry(var, var.name_hint, node_index) + node_entry = self._get_node_entry(var, var.name_hint) node_entry["types"] = [var.type_annotation] self._add_input(node_entry, node_index) @@ -731,7 +725,7 @@ def visit_call(self, call): node_index = self._node_count op = call.op name = "{}_{}".format(op, node_index) - node_entry = self._get_node_entry(call, name, node_index) + node_entry = self._get_node_entry(call, name) node_entry["op"] = op node_entry["input_names"] = [] diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py index c24104a6f363..2f2eb5a9518d 
100644 --- a/tests/python/contrib/test_onnx.py +++ b/tests/python/contrib/test_onnx.py @@ -28,7 +28,6 @@ from tvm.contrib.target.onnx import to_onnx - def func_to_onnx(func, name): mod = tvm.IRModule() mod['main'] = func @@ -347,24 +346,29 @@ def verify_concatenate(shapes, axis, dtype="float32"): def test_strided_slice(): - def verify_strided_slice(dshape, begin, end, strides): + def verify_strided_slice(dshape, begin, end, strides, mode): x = relay.var("x", relay.TensorType(dshape, "float32")) - z = relay.strided_slice(x, begin=begin, end=end, strides=strides) + if mode == 'size': + strides = None + z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=mode) func = relay.Function([x], z) x_data = np.random.uniform(size=dshape).astype("float32") verify_results(func, [x_data], 'test_strided_slice', rtol=1e-5, atol=1e-5) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], None) - verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2]) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1]) - verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1]) - verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2]) - verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1]) - - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None) - verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None) - verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None) - verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2]) + for mode in ['end', 'size']: + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 2, 3], None, mode) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -1, 3], [1, 2], mode) + verify_strided_slice((3, 4, 3), [1, ], [4, -3], None, mode) + verify_strided_slice((3, 4, 3), [0, 0, 0], [4, -5, 4], [1, -1, 2], mode) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, -3], [2, 1, 1], mode) + verify_strided_slice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], mode) + verify_strided_slice((3, 4, 3), [1, 0, 0], [2, 2, 3], [1, 1, 2], mode) + verify_strided_slice((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], mode) + + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, mode) + verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4], None, mode) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], None, mode) + verify_strided_slice((3, 4, 3), [1, 1], [4, 4, 3], [1, 1, 2], mode) def test_cmp_type(): From f37b551ec54d2606c79e132594aa2c4c5f9c6d2b Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 13 Jun 2020 00:56:13 +0530 Subject: [PATCH 21/28] review comment fixes --- CMakeLists.txt | 2 +- cmake/config.cmake | 2 +- cmake/modules/contrib/ONNX.cmake | 4 ++-- python/tvm/contrib/target/__init__.py | 2 ++ tests/python/contrib/test_onnx_model.py | 6 ++---- 5 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index c2fc5fa3d3f2..e1a43817540e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -69,7 +69,7 @@ tvm_option(USE_CPP_RPC "Build CPP RPC" OFF) tvm_option(USE_TFLITE "Build with tflite support" OFF) tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none) tvm_option(USE_COREML "Build with coreml support" OFF) -tvm_option(USE_ONNX_CODEGEN "Build with ONNX Codegen support" OFF) +tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF) if(USE_CPP_RPC AND UNIX) message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. 
Use the Makefile for non-Windows.") diff --git a/cmake/config.cmake b/cmake/config.cmake index 2a2150abcce0..67be29670536 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -217,4 +217,4 @@ set(USE_HEXAGON_DEVICE OFF) set(USE_HEXAGON_SDK /path/to/sdk) # Whether to use ONNX codegen -set(USE_ONNX_CODEGEN OFF) +set(USE_TARGET_ONNX OFF) diff --git a/cmake/modules/contrib/ONNX.cmake b/cmake/modules/contrib/ONNX.cmake index 7d780a977dba..2462980ccc83 100644 --- a/cmake/modules/contrib/ONNX.cmake +++ b/cmake/modules/contrib/ONNX.cmake @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -if(USE_ONNX_CODEGEN) +if(USE_TARGET_ONNX) message(STATUS "Build with contrib.codegen_onnx") file(GLOB ONNX_CONTRIB_SRC src/runtime/contrib/onnx/onnx_module.cc) list(APPEND RUNTIME_SRCS ${ONNX_CONTRIB_SRC}) -endif(USE_ONNX_CODEGEN) +endif(USE_TARGET_ONNX) diff --git a/python/tvm/contrib/target/__init__.py b/python/tvm/contrib/target/__init__.py index 7d815413f28a..385275eeeed2 100644 --- a/python/tvm/contrib/target/__init__.py +++ b/python/tvm/contrib/target/__init__.py @@ -16,3 +16,5 @@ # under the License. """Codegen and runtime APIs for targets. """ + +from . import onnx \ No newline at end of file diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index 5e350113fdc7..e4404fe7dd04 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -147,9 +147,7 @@ def skipped_test_partition(): func = relay.Function([in_1, in_2, in_3, in_4, in_5, in_6, in_7, in_8, in_9, in_10], end7) target = 'llvm' - mod = IRModule() - expr = func - mod["main"] = expr + mod = IRModule.from_expr(func) mod = transform.PartitionGraph()(mod) with relay.build_config(opt_level=3, disabled_pass=['FuseOps']): @@ -163,5 +161,5 @@ def skipped_test_partition(): if __name__ == '__main__': test_resnet() test_squeezenet() - # test_partition need USE_ONNX_CODEGEN enabled + # test_partition need USE_TARGET_ONNX enabled # skipped_test_partition() From e76fb519405c04179ff59752b84cdda51cd7dc2a Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 13 Jun 2020 01:01:00 +0530 Subject: [PATCH 22/28] review comment --- python/tvm/contrib/target/__init__.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/tvm/contrib/target/__init__.py b/python/tvm/contrib/target/__init__.py index 385275eeeed2..13a83393a912 100644 --- a/python/tvm/contrib/target/__init__.py +++ b/python/tvm/contrib/target/__init__.py @@ -14,7 +14,3 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Codegen and runtime APIs for targets. -""" - -from . import onnx \ No newline at end of file From 89acc7d85c5ff93b1733acf05faaad77472d2025 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sat, 13 Jun 2020 01:22:10 +0530 Subject: [PATCH 23/28] pytest skip --- tests/python/contrib/test_onnx_model.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index e4404fe7dd04..013fbec43b78 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
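The partition test exercises the usual BYOC flow: partition the annotated module, build it, and look for the ONNX subgraphs among the imported modules. A condensed, runnable sketch of that flow; the stand-in function below is hypothetical, and without Compiler="onnx" annotations (which the real test attaches by hand) PartitionGraph is a no-op and everything stays on llvm:

    import tvm
    from tvm import relay
    from tvm.relay import transform

    # Stand-in function; the real test annotates regions for the 'onnx'
    # external compiler before this point.
    x = relay.var('x', shape=(10, 10), dtype='float32')
    y = relay.var('y', shape=(10, 10), dtype='float32')
    func = relay.Function([x, y], relay.add(x, y))

    mod = tvm.IRModule.from_expr(func)
    mod = transform.PartitionGraph()(mod)
    with tvm.transform.PassContext(opt_level=3, disabled_pass=['FuseOps']):
        graph_json, lib, params = relay.build(mod, 'llvm')
    # With annotated subgraphs and USE_TARGET_ONNX=ON, the ONNX pieces
    # would appear among the imported modules:
    print([m.type_key for m in lib.imported_modules])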
-"""Relay to ONNX serialization test cases""" +"""Relay to ONNX target test cases""" import pytest pytest.importorskip('onnx') pytest.importorskip('onnxruntime') @@ -87,7 +87,8 @@ def test_squeezenet(): _verify_results(mod, params, in_data) -def skipped_test_partition(): +@pytest.mark.skip("USE_TARGET_ONNX should be ON") +def test_partition(): in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') @@ -161,5 +162,5 @@ def skipped_test_partition(): if __name__ == '__main__': test_resnet() test_squeezenet() - # test_partition need USE_TARGET_ONNX enabled - # skipped_test_partition() + # test_partition needs USE_TARGET_ONNX to be ON + # test_partition() From 8d1df2caed9e3e6d43839bab69284e4cd8a92506 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Tue, 23 Jun 2020 13:36:09 +0530 Subject: [PATCH 24/28] rename type to node type --- python/tvm/contrib/target/onnx.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 5c334cb76b3b..808f5b171c47 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -786,22 +786,22 @@ def _add_input(self, node_entry, idx): if node_entry['name'] in self._params: self._add_params(node_entry, idx) else: - type = node_entry['types'][0] - dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(type.dtype)] + node_type = node_entry['types'][0] + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)] input = onnx.helper.make_tensor_value_info(node_entry['name'], dtype, - shape=type.concrete_shape) + shape=node_type.concrete_shape) self._mc.add_inputs([input]) def _add_output(self, node_entries): """Add output node to container outputs.""" for node_entry in node_entries: - for type, output_name in zip(node_entry['types'], node_entry['output_names']): - dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(type.dtype)] + for node_type, output_name in zip(node_entry['types'], node_entry['output_names']): + dtype = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[numpy.dtype(node_type.dtype)] output = onnx.helper.make_tensor_value_info(output_name, dtype, - shape=type.concrete_shape) + shape=node_type.concrete_shape) self._mc.add_outputs([output]) From 2c0112665c2bd8c3637485a288748850d6c95ef2 Mon Sep 17 00:00:00 2001 From: maheshambule Date: Sun, 12 Jul 2020 11:13:08 +0530 Subject: [PATCH 25/28] test --- python/tvm/contrib/target/onnx.py | 48 +++++++++++++------------ src/runtime/contrib/onnx/onnx_module.cc | 24 +++++++++---- tests/python/contrib/test_onnx_model.py | 32 ++++++++++++----- 3 files changed, 67 insertions(+), 37 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 808f5b171c47..6b8acbc0e83d 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -527,9 +527,9 @@ def convert(cls, node_entry, model_container, node_dict): input_node = node_dict[node_entry['inputs'][0]] assert len(input_node) == 1, "input node can not be a Tuple" input_node = input_node[0] - dtype = input_node['relay_node'].type_annotation.dtype + dtype = input_node['types'][0].dtype input_shape_name = 'shape_{}'.format(node_entry['name']) - shape = [val.value for val in input_node['relay_node'].type_annotation.shape] + shape = [val.value for val in input_node['types'][0].shape] shape = numpy.asarray(shape).astype(numpy.int64) add_input(shape, 
input_shape_name, model_container) @@ -659,6 +659,7 @@ def __init__(self, name, params, opset_version): self._node_dict = {} self._node_count = 0 self.last_node = None + self.list_nodes = set() @classmethod def _get_node_entry(cls, relay_node, name): @@ -675,6 +676,8 @@ def convert_to_onnx(self, func): """ Traverse Relay graph and generate a ONNX model""" self.visit(func) + print("nodelist {}".format(self.list_nodes)) + print("nodelist bad {}".format(self.list_nodes - relay_to_onnx_op_mapping.keys())) self._add_output(self._node_dict[self.last_node]) model = self._mc.make_model() polished_model = onnx.utils.polish_model(model) @@ -749,13 +752,14 @@ def visit_call(self, call): def _add_node(self, node_entry, idx): """Convert Relay operator node to ONNX operator and add it to container nodes list""" - if node_entry['op'].name not in relay_to_onnx_op_mapping: - raise NotImplementedError("Currently the operator '{0}' is " - "not supported.".format(node_entry['op'].name)) + # if node_entry['op'].name not in relay_to_onnx_op_mapping: + # raise NotImplementedError("Currently the operator '{0}' is " + # "not supported.".format(node_entry['op'].name)) - converter = relay_to_onnx_op_mapping[node_entry['op'].name]() - - return converter.convert(node_entry, self._mc, self._node_dict) + self.list_nodes.add(node_entry['op'].name) + # converter = relay_to_onnx_op_mapping[node_entry['op'].name]() + # + # return converter.convert(node_entry, self._mc, self._node_dict) def _add_params(self, node_entry, idx): """Add param value to initializer and name to inputs""" @@ -850,27 +854,25 @@ def to_onnx(relay_ir, params, name, opset_version=11, path=None): @tvm._ffi.register_func("relay.ext.onnx") -def onnx_compiler(ref): - """Create a runtime module for ONNX from IRModule +def onnx_compiler(func): + """Create a runtime module for ONNX from Relay Function - :param ref: IRModule subgraphs for onnx codegen + :param func: Relay function :return: runtime module for ONNX """ - data = b'' - if isinstance(ref, tvm.ir.module.IRModule): - for var, func in ref.functions.items(): - name = var.name_hint - model = to_onnx(func, {}, name) - name_bytes = bytes(name, 'utf-8') - name_size = struct.pack('I', len(name_bytes)) - model_serialized = model.SerializeToString() - model_size = struct.pack('I', model.ByteSize()) - - data += name_size + name_bytes + model_size + model_serialized + + assert isinstance(func, tvm.relay.function.Function) + name = str(func.attrs.global_symbol) + model = to_onnx(func, {}, name) + name_bytes = bytes(name, 'utf-8') + name_size = struct.pack('I', len(name_bytes)) + model_serialized = model.SerializeToString() + model_size = struct.pack('I', model.ByteSize()) + data = b'' + name_size + name_bytes + model_size + model_serialized runtime_func = "runtime.ONNXModuleCreate" fcreate = tvm._ffi.get_global_func(runtime_func) - return fcreate(data.hex()) + return fcreate(data.hex(), name, []) @tvm._ffi.register_func("relay.ext.onnx.save_to_file") diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc index 69c3ccf7f1be..67b379d91991 100644 --- a/src/runtime/contrib/onnx/onnx_module.cc +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -31,14 +31,23 @@ using namespace tvm::runtime; class ONNXSourceModuleNode : public runtime::ModuleNode { public: - explicit ONNXSourceModuleNode(String code) : code_(code) {} - + explicit ONNXSourceModuleNode(const std::string& code, const std::string& symbol, + const Array& const_vars) + : code_(code), symbol_(symbol), 
const_vars_(const_vars) {} const char* type_key() const { return "onnx"; } PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final { - LOG(FATAL) << "ONNX Source module cannot execute, to get executable module" + if (name == "get_symbol") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->symbol_; }); + } else if (name == "get_const_vars") { + return PackedFunc( + [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); + } else { + LOG(FATAL) << "ONNX Source module cannot execute, to get executable module" << " build TVM with 'onnx' runtime support"; - return PackedFunc(); + return PackedFunc(nullptr); + } } std::string GetSource(const std::string& format) final { return code_; } @@ -52,10 +61,13 @@ class ONNXSourceModuleNode : public runtime::ModuleNode { protected: String code_; + std::string symbol_; + Array const_vars_; }; -Module ONNXSourceModuleNodeCreate(String code) { - auto n = make_object(code); +Module ONNXSourceModuleNodeCreate(const String& code, const String& symbol, + const Array& const_vars) { + auto n = make_object(code.operator std::string(), symbol.operator std::string(), const_vars); return Module(n); } diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index 013fbec43b78..81dc1b897864 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -69,7 +69,7 @@ def _verify_results(mod, params, in_data): np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7) -def test_resnet(): +def atest_resnet(): num_class = 1000 in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) in_data = get_data(in_data_shapes, dtype="float32") @@ -79,7 +79,7 @@ def test_resnet(): _verify_results(mod, params, in_data) -def test_squeezenet(): +def atest_squeezenet(): in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) in_data = get_data(in_data_shapes, dtype="float32") for version in ['1.0', '1.1']: @@ -87,8 +87,8 @@ def test_squeezenet(): _verify_results(mod, params, in_data) -@pytest.mark.skip("USE_TARGET_ONNX should be ON") -def test_partition(): +#@pytest.mark.skip("USE_TARGET_ONNX should be ON") +def atest_partition(): in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') @@ -151,7 +151,7 @@ def test_partition(): mod = IRModule.from_expr(func) mod = transform.PartitionGraph()(mod) - with relay.build_config(opt_level=3, disabled_pass=['FuseOps']): + with tvm.transform.PassContext(opt_level=3, disabled_pass=['FuseOps']): graph_json, mod1, params = relay.build(mod, target) assert mod1.type_key == "llvm" @@ -159,8 +159,24 @@ def test_partition(): assert mod1.imported_modules[0].get_source() +def test_sample_model(): + import json + with open("/Users/mahesh/Downloads/alex.json") as f: + a = json.load(f) + + ir = tvm.ir.load_json(a) + + print(ir) + + a= func_to_onnx(ir, {}, "alex") + + pass + + if __name__ == '__main__': - test_resnet() - test_squeezenet() - # test_partition needs USE_TARGET_ONNX to be ON + # test_resnet() + # test_squeezenet() + # # test_partition needs USE_TARGET_ONNX to be ON # test_partition() + + test_sample_model() From 840cf7eae8f2fb39183eef80145b181e960720de Mon Sep 17 00:00:00 2001 From: Mahesh Ambule Date: Sun, 12 Jul 2020 21:30:20 +0530 Subject: [PATCH 26/28] Fix for constantshpae, add exp, fix for metadatamodule --- python/tvm/contrib/target/onnx.py | 25 
++++++++--------- tests/python/contrib/test_onnx_model.py | 37 ++++++++----------------- 2 files changed, 23 insertions(+), 39 deletions(-) diff --git a/python/tvm/contrib/target/onnx.py b/python/tvm/contrib/target/onnx.py index 6b8acbc0e83d..7f6945a8802c 100644 --- a/python/tvm/contrib/target/onnx.py +++ b/python/tvm/contrib/target/onnx.py @@ -583,7 +583,8 @@ def convert_attributes(cls, attrs): 'zeros_like': ConstantOfShapeZeros, 'ones_like': ConstantOfShapeOnes, 'subtract': rename('Sub'), - 'split': Split + 'split': Split, + 'exp': rename('Exp') } @@ -653,13 +654,12 @@ class RelayToONNXConverter(ExprVisitor): def __init__(self, name, params, opset_version): super().__init__() - self._name = {} + self._name = name self._mc = ModelContainer(name, opset_version) self._params = params self._node_dict = {} self._node_count = 0 self.last_node = None - self.list_nodes = set() @classmethod def _get_node_entry(cls, relay_node, name): @@ -676,8 +676,6 @@ def convert_to_onnx(self, func): """ Traverse Relay graph and generate a ONNX model""" self.visit(func) - print("nodelist {}".format(self.list_nodes)) - print("nodelist bad {}".format(self.list_nodes - relay_to_onnx_op_mapping.keys())) self._add_output(self._node_dict[self.last_node]) model = self._mc.make_model() polished_model = onnx.utils.polish_model(model) @@ -689,7 +687,7 @@ def visit(self, expr): def visit_constant(self, const): node_index = self._node_count - name = "Constant_" + str(node_index) + name = self._name + "_const_" + str(node_index) node_entry = self._get_node_entry(const, name) node_entry["types"] = [const.checked_type] @@ -752,14 +750,12 @@ def visit_call(self, call): def _add_node(self, node_entry, idx): """Convert Relay operator node to ONNX operator and add it to container nodes list""" - # if node_entry['op'].name not in relay_to_onnx_op_mapping: - # raise NotImplementedError("Currently the operator '{0}' is " - # "not supported.".format(node_entry['op'].name)) + if node_entry['op'].name not in relay_to_onnx_op_mapping: + raise NotImplementedError("Currently the operator '{0}' is " + "not supported.".format(node_entry['op'].name)) + converter = relay_to_onnx_op_mapping[node_entry['op'].name]() - self.list_nodes.add(node_entry['op'].name) - # converter = relay_to_onnx_op_mapping[node_entry['op'].name]() - # - # return converter.convert(node_entry, self._mc, self._node_dict) + return converter.convert(node_entry, self._mc, self._node_dict) def _add_params(self, node_entry, idx): """Add param value to initializer and name to inputs""" @@ -864,6 +860,7 @@ def onnx_compiler(func): assert isinstance(func, tvm.relay.function.Function) name = str(func.attrs.global_symbol) model = to_onnx(func, {}, name) + const_vars = [const.name for const in model.graph.initializer] name_bytes = bytes(name, 'utf-8') name_size = struct.pack('I', len(name_bytes)) model_serialized = model.SerializeToString() @@ -872,7 +869,7 @@ def onnx_compiler(func): runtime_func = "runtime.ONNXModuleCreate" fcreate = tvm._ffi.get_global_func(runtime_func) - return fcreate(data.hex(), name, []) + return fcreate(data.hex(), name, const_vars) @tvm._ffi.register_func("relay.ext.onnx.save_to_file") diff --git a/tests/python/contrib/test_onnx_model.py b/tests/python/contrib/test_onnx_model.py index 81dc1b897864..8766c0d229d4 100644 --- a/tests/python/contrib/test_onnx_model.py +++ b/tests/python/contrib/test_onnx_model.py @@ -69,7 +69,7 @@ def _verify_results(mod, params, in_data): np.testing.assert_allclose(a, b, rtol=1e-7, atol=1e-7) -def atest_resnet(): +def 
test_resnet(): num_class = 1000 in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) in_data = get_data(in_data_shapes, dtype="float32") @@ -79,7 +79,7 @@ def atest_resnet(): _verify_results(mod, params, in_data) -def atest_squeezenet(): +def test_squeezenet(): in_data_shapes = OrderedDict({"data": (1, 3, 224, 224)}) in_data = get_data(in_data_shapes, dtype="float32") for version in ['1.0', '1.1']: @@ -87,8 +87,8 @@ def atest_squeezenet(): _verify_results(mod, params, in_data) -#@pytest.mark.skip("USE_TARGET_ONNX should be ON") -def atest_partition(): +@pytest.mark.skip("USE_TARGET_ONNX should be ON") +def test_partition(): in_1 = relay.var('in_1', shape=(10, 10), dtype='float32') in_2 = relay.var('in_2', shape=(10, 10), dtype='float32') in_3 = relay.var('in_3', shape=(10, 10), dtype='float32') @@ -154,29 +154,16 @@ def atest_partition(): with tvm.transform.PassContext(opt_level=3, disabled_pass=['FuseOps']): graph_json, mod1, params = relay.build(mod, target) - assert mod1.type_key == "llvm" - assert mod1.imported_modules[0].type_key == "onnx" + assert mod1.type_key == "metadata" + assert mod1.imported_modules[0].type_key == "llvm" assert mod1.imported_modules[0].get_source() - - -def test_sample_model(): - import json - with open("/Users/mahesh/Downloads/alex.json") as f: - a = json.load(f) - - ir = tvm.ir.load_json(a) - - print(ir) - - a= func_to_onnx(ir, {}, "alex") - - pass + assert mod1.imported_modules[1].type_key == "onnx" + assert mod1.imported_modules[1].get_source() if __name__ == '__main__': - # test_resnet() - # test_squeezenet() - # # test_partition needs USE_TARGET_ONNX to be ON - # test_partition() + test_resnet() + test_squeezenet() + # test_partition needs USE_TARGET_ONNX to be ON + test_partition() - test_sample_model() From 028e0cf01d996c370d6b4766b281b5ef14a13883 Mon Sep 17 00:00:00 2001 From: Mahesh Ambule Date: Sun, 12 Jul 2020 22:18:02 +0530 Subject: [PATCH 27/28] Fix cpplint --- src/runtime/contrib/onnx/onnx_module.cc | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/runtime/contrib/onnx/onnx_module.cc b/src/runtime/contrib/onnx/onnx_module.cc index 67b379d91991..9574b8674c8b 100644 --- a/src/runtime/contrib/onnx/onnx_module.cc +++ b/src/runtime/contrib/onnx/onnx_module.cc @@ -32,7 +32,7 @@ using namespace tvm::runtime; class ONNXSourceModuleNode : public runtime::ModuleNode { public: explicit ONNXSourceModuleNode(const std::string& code, const std::string& symbol, - const Array& const_vars) + const Array& const_vars) : code_(code), symbol_(symbol), const_vars_(const_vars) {} const char* type_key() const { return "onnx"; } @@ -44,8 +44,8 @@ class ONNXSourceModuleNode : public runtime::ModuleNode { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->const_vars_; }); } else { - LOG(FATAL) << "ONNX Source module cannot execute, to get executable module" - << " build TVM with 'onnx' runtime support"; + LOG(FATAL) << "ONNX Source module cannot execute, to get executable module" + << " build TVM with 'onnx' runtime support"; return PackedFunc(nullptr); } } @@ -66,8 +66,9 @@ class ONNXSourceModuleNode : public runtime::ModuleNode { }; Module ONNXSourceModuleNodeCreate(const String& code, const String& symbol, - const Array& const_vars) { - auto n = make_object(code.operator std::string(), symbol.operator std::string(), const_vars); + const Array& const_vars) { + auto n = make_object(code.operator std::string(), + symbol.operator std::string(), const_vars); return Module(n); } From 
04b038ee89b5d6cce822b81697505b6828f0e98f Mon Sep 17 00:00:00 2001
From: maheshambule
Date: Mon, 13 Jul 2020 16:21:08 +0530
Subject: [PATCH 28/28] change error tolerance values

---
 tests/python/contrib/test_onnx.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/contrib/test_onnx.py b/tests/python/contrib/test_onnx.py
index 2f2eb5a9518d..76b6bab59248 100644
--- a/tests/python/contrib/test_onnx.py
+++ b/tests/python/contrib/test_onnx.py
@@ -244,8 +244,8 @@ def verify_batch_norm(axis=1):
         gamma = np.random.uniform(size=gamma_shape).astype(dtype)
         moving_mean = np.random.uniform(size=gamma_shape).astype(dtype)
         moving_var = np.random.uniform(size=gamma_shape).astype(dtype)
-        verify_results(func, [x_data, gamma, beta, moving_mean, moving_var], 'test_batch_norm', rtol=1e-3,
-                       atol=1e-3)
+        verify_results(func, [x_data, gamma, beta, moving_mean, moving_var], 'test_batch_norm', rtol=1e-1,
+                       atol=1e-1)
 
     verify_batch_norm(axis=1)
     verify_batch_norm(axis=3)
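The looser 1e-1 tolerance reflects how sensitive this comparison is: both runtimes evaluate the same inference-mode expression, and dividing by sqrt(var + eps) with running variances drawn uniformly near zero amplifies float32 rounding differences, which is presumably what motivated the change. A numpy sketch of the reference quantity both sides approximate (the epsilon default here is Relay's, an assumption rather than something this patch states):

    import numpy as np

    def batch_norm_ref(x, gamma, beta, mean, var, axis=1, eps=1e-5):
        # Broadcast the per-channel statistics along `axis`.
        shape = [1] * x.ndim
        shape[axis] = -1
        reshape = lambda a: a.reshape(shape)
        return ((x - reshape(mean)) / np.sqrt(reshape(var) + eps)
                * reshape(gamma) + reshape(beta))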