From 74c125625e333610e5dfca5fdf78d166bbddefad Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Wed, 22 May 2019 18:34:16 -0700 Subject: [PATCH 1/9] Fix reshape precompute, and type error --- python/tvm/relay/frontend/onnx.py | 36 ++++++++++++---------- tests/python/frontend/onnx/test_forward.py | 14 +++++++++ 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 18253e498560..0a7b2ed9f85e 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -408,21 +408,20 @@ def _impl_v1(cls, inputs, attr, params): shape = tuple(params[inputs[1].name_hint].asnumpy()) out = _op.reshape(inputs[0], shape) else: - # Try to infer shape by precompute prune if possible. - # TODO: good to check inputs to be in params. - # to be enhanced when relay support list_input_names API of NNVM - logging.warning("Infering Reshape argument by precompute") - func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1]) + data, shape = inputs + logging.warning("Constant evaluating Reshape's shape argument, may reduce performance") + shape_params = ir_pass.free_vars(shape) + func = _expr.Function(shape_params, shape) with tvm.relay.build_config(opt_level=0): - graph, lib, params = tvm.relay.build(func, target="llvm", params=params) - ctx = tvm.context("llvm", 0) - from tvm.contrib import graph_runtime - m = graph_runtime.create(graph, lib, ctx) - m.set_input(**params) - m.run() - params_new = m.get_output(0) - inputs.pop(1) - out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten())) + ex = tvm.relay.create_executor("debug") + inputs = [] + for sp in shape_params: + if not sp.name_hint in params: + sh = [int(i) for i in sp.type_annotation.shape] + inputs.append( + tvm.nd.array(np.random.rand(*sh).astype('float32'))) + static_shape = ex.evaluate(func)(*inputs, **params) + out = _op.reshape(data, newshape=tuple(static_shape.asnumpy())) return out @@ -567,6 +566,7 @@ class Shape(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): + # TODO(@jroesch): use shape_of once it has been fixed return _op.shape_of(inputs[0]) class Cast(OnnxOpConverter): @@ -1056,8 +1056,12 @@ def from_onnx(self, graph, opset): if op_name == "Constant": t_proto = self._parse_attr(node.attribute)["value"] self._num_param += 1 - self._params[node.output[0]] = self._parse_array(t_proto) - self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims)) + # We should convert scalar integers to int32, to normalize. + array = self._parse_array(t_proto) + if len(array.shape) == 0 and array.dtype == 'int64': + array = _nd.array(array.asnumpy().astype('int32')) + self._params[node.output[0]] = array + self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims), dtype=array.dtype) else: if op_name == "ConstantFill": fill_value = attr.get('value', 0.0) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 095f1feb246a..fd67dfd933ef 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -14,8 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import attr import numpy as np import math +import mxnet as mx +import torch +import torchvision import topi import topi.testing import tvm @@ -1070,6 +1074,14 @@ def test_LogSoftmax(): 'LogSoftmax', {'axis': 1}) +def test_resnet18_pytorch(): + dummy_input = torch.randn(1,3,224,224) + model = torchvision.models.resnet18() + onnx_model = torch.onnx.export(model, dummy_input, 'resnet-18.onnx', export_params=True, verbose=True) + model = onnx.load('resnet-18.onnx') + shapes = { '0' : (1, 3, 224, 224) } + expr, params = relay.frontend.from_onnx(model, shape=shapes) + if __name__ == '__main__': test_flatten() test_reshape() @@ -1109,3 +1121,5 @@ def test_LogSoftmax(): test_ParametricSoftplus() test_Scale() test_LogSoftmax() + test_resnet18() + test_resnet18_pytorch() From 310d4d59ebdc9f24122a7443618adbeccc9981a9 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 May 2019 01:35:43 -0700 Subject: [PATCH 2/9] Add more models, it seems like torch's ONNX export is limited --- tests/python/frontend/onnx/test_forward.py | 129 +++++++++++++-------- 1 file changed, 81 insertions(+), 48 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index fd67dfd933ef..5ed7681dc082 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -17,7 +17,6 @@ import attr import numpy as np import math -import mxnet as mx import torch import torchvision import topi @@ -1074,52 +1073,86 @@ def test_LogSoftmax(): 'LogSoftmax', {'axis': 1}) -def test_resnet18_pytorch(): - dummy_input = torch.randn(1,3,224,224) - model = torchvision.models.resnet18() - onnx_model = torch.onnx.export(model, dummy_input, 'resnet-18.onnx', export_params=True, verbose=True) - model = onnx.load('resnet-18.onnx') - shapes = { '0' : (1, 3, 224, 224) } - expr, params = relay.frontend.from_onnx(model, shape=shapes) +def check_torch_conversion(model, input_size): + dummy_input = torch.randn(*input_size) + file_name = '{}.onnx'.format(model.__name__) + torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=True) + onnx_model = onnx.load(file_name) + shapes = { '0' : input_size } + expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes) + +def test_resnet(): + check_torch_conversion(torchvision.models.resnet18, (1,3,224,224)) + # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224)) + +# def test_alexnet(): + # Torch's ONNX export does not support the adaptive pooling used by AlexNet? + # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224)) + +# Torch's ONNX export does not support the adaptive pooling used by vgg16? +# def test_vgg16(): +# check_torch_conversion(torchvision.models.vgg16, (1,3,224,224)) + +def test_squeezenet(): + check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) + +def test_densenet(): + check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) + +def test_inception(): + check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224)) + +def test_googlenet(): + check_torch_conversion(torchvision.models.googlenet, (1,3,224,224)) + +def test_shufflenetv2(): + check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) + if __name__ == '__main__': - test_flatten() - test_reshape() - test_shape() - test_power() - test_squeeze() - test_unsqueeze() - test_slice() - test_floor() - test_ceil() - test_clip() - test_matmul() - test_gather() - test_lrn() - test_upsample() - test_forward_min() - test_forward_max() - test_forward_mean() - test_forward_hardsigmoid() - test_forward_arg_min_max() - test_softmax() - test_constantfill() - test_pad() - test_reduce_max() - test_reduce_min() - test_reduce_sum() - test_reduce_mean() - test_pad() - test_split() - test_binary_ops() - test_single_ops() - test_leaky_relu() - test_elu() - test_selu() - test_ThresholdedRelu() - test_ScaledTanh() - test_ParametricSoftplus() - test_Scale() - test_LogSoftmax() - test_resnet18() - test_resnet18_pytorch() + # test_flatten() + # test_reshape() + # test_shape() + # test_power() + # test_squeeze() + # test_unsqueeze() + # test_slice() + # test_floor() + # test_ceil() + # test_clip() + # test_matmul() + # test_gather() + # test_lrn() + # test_upsample() + # test_forward_min() + # test_forward_max() + # test_forward_mean() + # test_forward_hardsigmoid() + # test_forward_arg_min_max() + # test_softmax() + # test_constantfill() + # test_pad() + # test_reduce_max() + # test_reduce_min() + # test_reduce_sum() + # test_reduce_mean() + # test_pad() + # test_split() + # test_binary_ops() + # test_single_ops() + # test_leaky_relu() + # test_elu() + # test_selu() + # test_ThresholdedRelu() + # test_ScaledTanh() + # test_ParametricSoftplus() + # test_Scale() + # test_LogSoftmax() + test_resnet() + # test_alexnet() + # test_vgg16() + test_squeezenet() + test_densenet() + test_inception() + test_googlenet() + test_shufflenetv2() From dc15a4b4207bec0f2cea7869fafea6e4e730a763 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 May 2019 01:37:17 -0700 Subject: [PATCH 3/9] Restore existing tests --- tests/python/frontend/onnx/test_forward.py | 80 +++++++++++----------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 5ed7681dc082..8ed387ac5559 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1110,47 +1110,47 @@ def test_shufflenetv2(): if __name__ == '__main__': - # test_flatten() - # test_reshape() - # test_shape() - # test_power() - # test_squeeze() - # test_unsqueeze() - # test_slice() - # test_floor() - # test_ceil() - # test_clip() - # test_matmul() - # test_gather() - # test_lrn() - # test_upsample() - # test_forward_min() - # test_forward_max() - # test_forward_mean() - # test_forward_hardsigmoid() - # test_forward_arg_min_max() - # test_softmax() - # test_constantfill() - # test_pad() - # test_reduce_max() - # test_reduce_min() - # test_reduce_sum() - # test_reduce_mean() - # test_pad() - # test_split() - # test_binary_ops() - # test_single_ops() - # test_leaky_relu() - # test_elu() - # test_selu() - # test_ThresholdedRelu() - # test_ScaledTanh() - # test_ParametricSoftplus() - # test_Scale() - # test_LogSoftmax() + test_flatten() + test_reshape() + test_shape() + test_power() + test_squeeze() + test_unsqueeze() + test_slice() + test_floor() + test_ceil() + test_clip() + test_matmul() + test_gather() + test_lrn() + test_upsample() + test_forward_min() + test_forward_max() + test_forward_mean() + test_forward_hardsigmoid() + test_forward_arg_min_max() + test_softmax() + test_constantfill() + test_pad() + test_reduce_max() + test_reduce_min() + test_reduce_sum() + test_reduce_mean() + test_pad() + test_split() + test_binary_ops() + test_single_ops() + test_leaky_relu() + test_elu() + test_selu() + test_ThresholdedRelu() + test_ScaledTanh() + test_ParametricSoftplus() + test_Scale() + test_LogSoftmax() test_resnet() - # test_alexnet() - # test_vgg16() + test_alexnet() + test_vgg16() test_squeezenet() test_densenet() test_inception() From 74cd3800b0d3e40d2eb41d3927c07cac5da9eea7 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Thu, 23 May 2019 16:23:11 -0700 Subject: [PATCH 4/9] Wrap line --- python/tvm/relay/frontend/onnx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 0a7b2ed9f85e..997d774431a5 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1061,7 +1061,10 @@ def from_onnx(self, graph, opset): if len(array.shape) == 0 and array.dtype == 'int64': array = _nd.array(array.asnumpy().astype('int32')) self._params[node.output[0]] = array - self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims), dtype=array.dtype) + self._nodes[node.output[0]] = new_var( + node.output[0], + shape=list(t_proto.dims), + dtype=array.dtype) else: if op_name == "ConstantFill": fill_value = attr.get('value', 0.0) From 761aeb8134cc1c8f9befbadc7d058de9d0870a91 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Mon, 27 May 2019 21:35:42 -0700 Subject: [PATCH 5/9] Disable models which do not work --- tests/python/frontend/onnx/test_forward.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 8ed387ac5559..1caecc41d76f 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1105,8 +1105,9 @@ def test_inception(): def test_googlenet(): check_torch_conversion(torchvision.models.googlenet, (1,3,224,224)) -def test_shufflenetv2(): - check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) +# TODO(@jroesch): Update Torch + ONNX to support this import. +# def test_shufflenetv2(): +# check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224)) if __name__ == '__main__': @@ -1149,10 +1150,10 @@ def test_shufflenetv2(): test_Scale() test_LogSoftmax() test_resnet() - test_alexnet() - test_vgg16() + # test_alexnet() + # test_vgg16() test_squeezenet() test_densenet() test_inception() test_googlenet() - test_shufflenetv2() + # test_shufflenetv2() From 356e14c2a8d01064ca01c10c7785a6b2f8edaa61 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Jun 2019 14:54:27 -0700 Subject: [PATCH 6/9] More tests don't work due to PyTorch --- tests/python/frontend/onnx/test_forward.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 1caecc41d76f..543897f52f15 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1076,7 +1076,8 @@ def test_LogSoftmax(): def check_torch_conversion(model, input_size): dummy_input = torch.randn(*input_size) file_name = '{}.onnx'.format(model.__name__) - torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=True) + # Set verbose=True for more output + torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False) onnx_model = onnx.load(file_name) shapes = { '0' : input_size } expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes) @@ -1093,9 +1094,12 @@ def test_resnet(): # def test_vgg16(): # check_torch_conversion(torchvision.models.vgg16, (1,3,224,224)) -def test_squeezenet(): - check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) +# TODO(@jroesch): Update Torch + ONNX to support this import. +# def test_squeezenet(): +# # Torch's ONNX export does not support the max pooling used by Squezenet +# check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) +# TODO(@jroesch): Update Torch + ONNX to support this import. def test_densenet(): check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) @@ -1150,10 +1154,6 @@ def test_googlenet(): test_Scale() test_LogSoftmax() test_resnet() - # test_alexnet() - # test_vgg16() - test_squeezenet() test_densenet() test_inception() test_googlenet() - # test_shufflenetv2() From 141e0a29fc3afbbd7f159528d359ebd09e7af813 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Jun 2019 17:01:33 -0700 Subject: [PATCH 7/9] Remove --- tests/python/frontend/onnx/test_forward.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 543897f52f15..269c3460c166 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1100,8 +1100,8 @@ def test_resnet(): # check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) # TODO(@jroesch): Update Torch + ONNX to support this import. -def test_densenet(): - check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) +# def test_densenet(): +# check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) def test_inception(): check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224)) @@ -1154,6 +1154,5 @@ def test_googlenet(): test_Scale() test_LogSoftmax() test_resnet() - test_densenet() test_inception() test_googlenet() From 704d1ab75e62b88d0078cfc94a043d908b616165 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Fri, 14 Jun 2019 17:13:01 -0700 Subject: [PATCH 8/9] Improve performance --- python/tvm/relay/frontend/onnx.py | 4 ++++ tests/python/frontend/onnx/test_forward.py | 5 +++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 997d774431a5..78bcd8b0871c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -412,6 +412,10 @@ def _impl_v1(cls, inputs, attr, params): logging.warning("Constant evaluating Reshape's shape argument, may reduce performance") shape_params = ir_pass.free_vars(shape) func = _expr.Function(shape_params, shape) + func = ir_pass.infer_type(func) + func = ir_pass.fold_constant(func) + shape_params = ir_pass.free_vars(func.body) + func = _expr.Function(shape_params, func.body) with tvm.relay.build_config(opt_level=0): ex = tvm.relay.create_executor("debug") inputs = [] diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 269c3460c166..d4e577024eb0 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1100,8 +1100,8 @@ def test_resnet(): # check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) # TODO(@jroesch): Update Torch + ONNX to support this import. -# def test_densenet(): -# check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) +def test_densenet(): + check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) def test_inception(): check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224)) @@ -1156,3 +1156,4 @@ def test_googlenet(): test_resnet() test_inception() test_googlenet() + test_densenet() From adfdeb34db0ed033d6d1223a7189f79388dbcac5 Mon Sep 17 00:00:00 2001 From: Jared Roesch Date: Sat, 15 Jun 2019 13:16:08 -0700 Subject: [PATCH 9/9] Diable lenet for now --- tests/python/frontend/onnx/test_forward.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index d4e577024eb0..60d86589ae49 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1099,15 +1099,15 @@ def test_resnet(): # # Torch's ONNX export does not support the max pooling used by Squezenet # check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224)) -# TODO(@jroesch): Update Torch + ONNX to support this import. def test_densenet(): check_torch_conversion(torchvision.models.densenet161, (1,3,224,224)) def test_inception(): check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224)) -def test_googlenet(): - check_torch_conversion(torchvision.models.googlenet, (1,3,224,224)) +# TODO(@jroesch): Update Torch + ONNX to support this import. +# def test_googlenet(): +# check_torch_conversion(torchvision.models.googlenet, (1,3,224,224)) # TODO(@jroesch): Update Torch + ONNX to support this import. # def test_shufflenetv2(): @@ -1155,5 +1155,4 @@ def test_googlenet(): test_LogSoftmax() test_resnet() test_inception() - test_googlenet() test_densenet()