diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index a28981c417fa..468a7486ca5c 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -409,24 +409,21 @@ def _impl_v1(cls, inputs, attr, params):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            data, shape = inputs
-            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
-            shape_params = ir_pass.free_vars(shape)
-            func = _expr.Function(shape_params, shape)
-            func = ir_pass.infer_type(func)
-            func = ir_pass.fold_constant(func)
-            shape_params = ir_pass.free_vars(func.body)
-            func = _expr.Function(shape_params, func.body)
+            # Try to infer the shape by precompute pruning if possible.
+            # TODO: it would be good to check that the inputs are in params;
+            # to be enhanced when Relay supports a list_input_names API like NNVM's.
+            logging.warning("Inferring Reshape argument by precompute")
+            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
             with tvm.relay.build_config(opt_level=0):
-                ex = tvm.relay.create_executor("debug")
-                inputs = []
-                for sp in shape_params:
-                    if not sp.name_hint in params:
-                        sh = [int(i) for i in sp.type_annotation.shape]
-                        inputs.append(
-                            tvm.nd.array(np.random.rand(*sh).astype('float32')))
-                static_shape = ex.evaluate(func)(*inputs, **params)
-            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
+                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
+            ctx = tvm.context("llvm", 0)
+            from tvm.contrib import graph_runtime
+            m = graph_runtime.create(graph, lib, ctx)
+            m.set_input(**params)
+            m.run()
+            params_new = m.get_output(0)
+            inputs.pop(1)
+            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
         return out
@@ -571,7 +568,6 @@ class Shape(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
-        # TODO(@jroesch): use shape_of once it has been fixed
         return _op.shape_of(inputs[0])
 
 class Cast(OnnxOpConverter):
@@ -1062,15 +1058,8 @@ def from_onnx(self, graph, opset):
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
                 self._num_param += 1
-                # We should convert scalar integers to int32, to normalize.
-                array = self._parse_array(t_proto)
-                if len(array.shape) == 0 and array.dtype == 'int64':
-                    array = _nd.array(array.asnumpy().astype('int32'))
-                self._params[node.output[0]] = array
-                self._nodes[node.output[0]] = new_var(
-                    node.output[0],
-                    shape=list(t_proto.dims),
-                    dtype=array.dtype)
+                self._params[node.output[0]] = self._parse_array(t_proto)
+                self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
             else:
                 if op_name == "ConstantFill":
                     fill_value = attr.get('value', 0.0)
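
Note (not part of the patch): the restored `else` branch above recovers a static shape for Reshape by compiling and running the shape subexpression once at import time. The following standalone sketch distills that pattern under the same pre-0.6 TVM API the patch uses (ir_pass, graph_runtime, tvm.context); `precompute_shape`, `shape_expr`, and `params` are illustrative names, not identifiers from the patch.

# A minimal sketch of the precompute pattern restored above; assumes the
# contemporary (pre-0.6) TVM API. `shape_expr` is the Relay expression that
# feeds Reshape's shape input; `params` maps its free variables to constants.
import tvm
from tvm import relay
from tvm.relay import expr as _expr
from tvm.relay import ir_pass
from tvm.contrib import graph_runtime

def precompute_shape(shape_expr, params):
    # Wrap the shape expression in a function over its free variables so it
    # can be compiled on its own.
    func = _expr.Function(ir_pass.free_vars(shape_expr), shape_expr)
    # Build without optimization; the params are bound at build time.
    with relay.build_config(opt_level=0):
        graph, lib, params = relay.build(func, target="llvm", params=params)
    # Run once on CPU and read back the now-static shape values.
    m = graph_runtime.create(graph, lib, tvm.context("llvm", 0))
    m.set_input(**params)
    m.run()
    return tuple(m.get_output(0).asnumpy().astype('int32').flatten())
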
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index d4c8ee9deeee..7371a88ca677 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -14,11 +14,8 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import attr
 import numpy as np
 import math
-import torch
-import torchvision
 import topi
 import topi.testing
 import tvm
@@ -1075,47 +1072,6 @@ def test_LogSoftmax():
                'LogSoftmax', {'axis': 1})
 
 
-def check_torch_conversion(model, input_size):
-    dummy_input = torch.randn(*input_size)
-    file_name = '{}.onnx'.format(model.__name__)
-    # Set verbose=True for more output
-    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
-    onnx_model = onnx.load(file_name)
-    shapes = { '0' : input_size }
-    expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)
-
-def test_resnet():
-    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
-    # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
-
-# def test_alexnet():
-    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
-    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
-
-# Torch's ONNX export does not support the adaptive pooling used by vgg16?
-# def test_vgg16():
-#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
-
-# TODO(@jroesch): Update Torch + ONNX to support this import.
-# def test_squeezenet():
-#     # Torch's ONNX export does not support the max pooling used by Squezenet
-#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
-
-def test_densenet():
-    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
-
-def test_inception():
-    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
-
-# TODO(@jroesch): Update Torch + ONNX to support this import.
-# def test_googlenet():
-#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
-
-# TODO(@jroesch): Update Torch + ONNX to support this import.
-# def test_shufflenetv2():
-#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
-
-
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1155,6 +1111,3 @@ def test_inception():
     test_ParametricSoftplus()
     test_Scale()
     test_LogSoftmax()
-    test_resnet()
-    test_inception()
-    test_densenet()
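
Note (not part of the patch): the torch/torchvision model tests above are removed along with their imports; coverage of the changed Reshape path still comes from test_reshape() in this file. For illustration only, here is a hypothetical test in the style of the surrounding tests that exercises the shape-as-initializer case; the name `verify_reshape_initializer` and the graph details are assumptions, not code from the patch.

# Hypothetical regression test sketch, not part of the patch: the target shape
# ships as an ONNX initializer, so the importer must treat it as a constant.
import numpy as np
import tvm
from tvm import relay
from tvm.contrib import graph_runtime
from onnx import helper, TensorProto

def verify_reshape_initializer():
    in_shape = (4, 3, 3, 4)
    ref_shape = (6, 2, 4, 3)
    # Provide the shape tensor as an initializer (and matching graph input,
    # as required by older ONNX IR versions).
    shape_tensor = helper.make_tensor('shape', TensorProto.INT64,
                                      [len(ref_shape)], list(ref_shape))
    node = helper.make_node('Reshape', inputs=['data', 'shape'], outputs=['out'])
    graph = helper.make_graph(
        [node], 'reshape_initializer_test',
        inputs=[helper.make_tensor_value_info('data', TensorProto.FLOAT, list(in_shape)),
                helper.make_tensor_value_info('shape', TensorProto.INT64, [len(ref_shape)])],
        outputs=[helper.make_tensor_value_info('out', TensorProto.FLOAT, list(ref_shape))],
        initializer=[shape_tensor])
    model = helper.make_model(graph, producer_name='reshape_initializer_test')
    sym, params = relay.frontend.from_onnx(model, {'data': in_shape})
    with relay.build_config(opt_level=1):
        graph_json, lib, params = relay.build(sym, target='llvm', params=params)
    m = graph_runtime.create(graph_json, lib, tvm.context('llvm', 0))
    x = np.random.uniform(size=in_shape).astype('float32')
    m.set_input('data', x)
    m.set_input(**params)
    m.run()
    np.testing.assert_allclose(m.get_output(0).asnumpy(), x.reshape(ref_shape), rtol=1e-5)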