diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 470f4197c908..c70f5aba39fe 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -934,7 +934,7 @@ def __init__(self, shape, dtype): self._renames = {} self._num_input = 0 self._num_param = 0 - self._shape = shape + self._shape = shape if shape else {} self._dtype = dtype def from_onnx(self, graph, opset): @@ -966,6 +966,9 @@ def from_onnx(self, graph, opset): if not init_tensor.name.strip(): raise ValueError("Tensor's name is required.") self._params[init_tensor.name] = self._parse_array(init_tensor) + self._nodes[init_tensor.name] = new_var(init_tensor.name, + shape=self._params[init_tensor.name].shape, + dtype=self._params[init_tensor.name].dtype) for i in graph.input: # from onnx v0.2, GraphProto.input has type ValueInfoProto, # and the name is 'i.name' @@ -1179,6 +1182,18 @@ def from_onnx(model, params : dict of str to tvm.NDArray The parameter dict to be used by relay """ + try: + import onnx + if hasattr(onnx.checker, 'check_model'): + # try use onnx's own model checker before converting any model + try: + onnx.checker.check_model(model) + except onnx.onnx_cpp2py_export.checker.ValidationError as e: + import warnings + # the checker is a bit violent about errors, so simply print warnings here + warnings.warn(str(e)) + except ImportError: + pass g = GraphProto(shape, dtype) graph = model.graph try: