diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 18253e498560..78bcd8b0871c 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -408,21 +408,24 @@ def _impl_v1(cls, inputs, attr, params):
             shape = tuple(params[inputs[1].name_hint].asnumpy())
             out = _op.reshape(inputs[0], shape)
         else:
-            # Try to infer shape by precompute prune if possible.
-            # TODO: good to check inputs to be in params.
-            # to be enhanced when relay support list_input_names API of NNVM
-            logging.warning("Infering Reshape argument by precompute")
-            func = _expr.Function(ir_pass.free_vars(inputs[1]), inputs[1])
+            data, shape = inputs
+            logging.warning("Constant evaluating Reshape's shape argument, may reduce performance")
+            shape_params = ir_pass.free_vars(shape)
+            func = _expr.Function(shape_params, shape)
+            func = ir_pass.infer_type(func)
+            func = ir_pass.fold_constant(func)
+            shape_params = ir_pass.free_vars(func.body)
+            func = _expr.Function(shape_params, func.body)
             with tvm.relay.build_config(opt_level=0):
-                graph, lib, params = tvm.relay.build(func, target="llvm", params=params)
-            ctx = tvm.context("llvm", 0)
-            from tvm.contrib import graph_runtime
-            m = graph_runtime.create(graph, lib, ctx)
-            m.set_input(**params)
-            m.run()
-            params_new = m.get_output(0)
-            inputs.pop(1)
-            out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
+                ex = tvm.relay.create_executor("debug")
+            inputs = []
+            for sp in shape_params:
+                if not sp.name_hint in params:
+                    sh = [int(i) for i in sp.type_annotation.shape]
+                    inputs.append(
+                        tvm.nd.array(np.random.rand(*sh).astype('float32')))
+            static_shape = ex.evaluate(func)(*inputs, **params)
+            out = _op.reshape(data, newshape=tuple(static_shape.asnumpy()))
         return out
 
 
@@ -567,6 +570,7 @@ class Shape(OnnxOpConverter):
 
     @classmethod
     def _impl_v1(cls, inputs, attr, params):
+        # TODO(@jroesch): use shape_of once it has been fixed
         return _op.shape_of(inputs[0])
 
 class Cast(OnnxOpConverter):
@@ -1056,8 +1060,15 @@ def from_onnx(self, graph, opset):
             if op_name == "Constant":
                 t_proto = self._parse_attr(node.attribute)["value"]
                 self._num_param += 1
-                self._params[node.output[0]] = self._parse_array(t_proto)
-                self._nodes[node.output[0]] = new_var(node.output[0], shape=list(t_proto.dims))
+                # We should convert scalar integers to int32, to normalize.
+                array = self._parse_array(t_proto)
+                if len(array.shape) == 0 and array.dtype == 'int64':
+                    array = _nd.array(array.asnumpy().astype('int32'))
+                self._params[node.output[0]] = array
+                self._nodes[node.output[0]] = new_var(
+                    node.output[0],
+                    shape=list(t_proto.dims),
+                    dtype=array.dtype)
             else:
                 if op_name == "ConstantFill":
                     fill_value = attr.get('value', 0.0)
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 095f1feb246a..60d86589ae49 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -14,8 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import attr
 import numpy as np
 import math
+import torch
+import torchvision
 import topi
 import topi.testing
 import tvm
@@ -1070,6 +1073,47 @@ def test_LogSoftmax():
                               'LogSoftmax', {'axis': 1})
 
 
+def check_torch_conversion(model, input_size):
+    dummy_input = torch.randn(*input_size)
+    file_name = '{}.onnx'.format(model.__name__)
+    # Set verbose=True for more output
+    torch.onnx.export(model(), dummy_input, file_name, export_params=True, verbose=False)
+    onnx_model = onnx.load(file_name)
+    shapes = { '0' : input_size }
+    expr, params = relay.frontend.from_onnx(onnx_model, shape=shapes)
+
+def test_resnet():
+    check_torch_conversion(torchvision.models.resnet18, (1,3,224,224))
+    # check_torch_conversion(torchvision.models.resnet101, (1,3,224,224))
+
+# def test_alexnet():
+    # Torch's ONNX export does not support the adaptive pooling used by AlexNet?
+    # check_torch_conversion(torchvision.models.alexnet, (1,3,224,224))
+
+# Torch's ONNX export does not support the adaptive pooling used by vgg16?
+# def test_vgg16():
+#     check_torch_conversion(torchvision.models.vgg16, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_squeezenet():
+#     # Torch's ONNX export does not support the max pooling used by Squeezenet
+#     check_torch_conversion(torchvision.models.squeezenet1_0, (1,3,224,224))
+
+def test_densenet():
+    check_torch_conversion(torchvision.models.densenet161, (1,3,224,224))
+
+def test_inception():
+    check_torch_conversion(torchvision.models.inception_v3, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_googlenet():
+#     check_torch_conversion(torchvision.models.googlenet, (1,3,224,224))
+
+# TODO(@jroesch): Update Torch + ONNX to support this import.
+# def test_shufflenetv2():
+#     check_torch_conversion(torchvision.models.shufflenetv2, (1,3,224,224))
+
+
 if __name__ == '__main__':
     test_flatten()
     test_reshape()
@@ -1109,3 +1153,6 @@ def test_LogSoftmax():
     test_ParametricSoftplus()
     test_Scale()
     test_LogSoftmax()
+    test_resnet()
+    test_inception()
+    test_densenet()
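
The new Reshape path in the first hunk works by constant-evaluating the shape
expression at import time: wrap the shape subgraph in a function over its free
variables, fold constants, then run it with the debug executor to recover a
static tuple. Below is a minimal standalone sketch of that trick, assuming the
relay API of this patch's era (tvm.relay.ir_pass, create_executor);
eval_shape_expr is a hypothetical helper name, not part of the patch:

import numpy as np
import tvm
from tvm import relay
from tvm.relay import ir_pass

def eval_shape_expr(shape_expr, params):
    # Close over the free variables of the shape expression, then fold
    # constants so statically-known subexpressions are precomputed.
    func = relay.Function(ir_pass.free_vars(shape_expr), shape_expr)
    func = ir_pass.infer_type(func)
    func = ir_pass.fold_constant(func)
    # Folding may eliminate free variables; re-collect the survivors.
    free = ir_pass.free_vars(func.body)
    func = relay.Function(free, func.body)
    with relay.build_config(opt_level=0):
        ex = relay.create_executor("debug")
    # Feed dummy tensors for vars not bound in params; a well-formed shape
    # computation depends only on tensor shapes, never on their values.
    args = [tvm.nd.array(np.zeros([int(d) for d in v.type_annotation.shape],
                                  dtype='float32'))
            for v in free if v.name_hint not in params]
    return tuple(ex.evaluate(func)(*args, **params).asnumpy())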