Skip to content

Commit

Permalink
[Relay][ONNX] fix #3134 converter where initializers were not registe…
Browse files Browse the repository at this point in the history
…red as nodes (#3143)
  • Loading branch information
zhreshold authored and tqchen committed May 20, 2019
1 parent d4fb0a2 commit 3a9de90
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -934,7 +934,7 @@ def __init__(self, shape, dtype):
self._renames = {}
self._num_input = 0
self._num_param = 0
self._shape = shape
self._shape = shape if shape else {}
self._dtype = dtype

def from_onnx(self, graph, opset):
Expand Down Expand Up @@ -966,6 +966,9 @@ def from_onnx(self, graph, opset):
if not init_tensor.name.strip():
raise ValueError("Tensor's name is required.")
self._params[init_tensor.name] = self._parse_array(init_tensor)
self._nodes[init_tensor.name] = new_var(init_tensor.name,
shape=self._params[init_tensor.name].shape,
dtype=self._params[init_tensor.name].dtype)
for i in graph.input:
# from onnx v0.2, GraphProto.input has type ValueInfoProto,
# and the name is 'i.name'
Expand Down Expand Up @@ -1179,6 +1182,18 @@ def from_onnx(model,
params : dict of str to tvm.NDArray
The parameter dict to be used by relay
"""
try:
import onnx
if hasattr(onnx.checker, 'check_model'):
# try use onnx's own model checker before converting any model
try:
onnx.checker.check_model(model)
except onnx.onnx_cpp2py_export.checker.ValidationError as e:
import warnings
# the checker is a bit violent about errors, so simply print warnings here
warnings.warn(str(e))
except ImportError:
pass
g = GraphProto(shape, dtype)
graph = model.graph
try:
Expand Down

0 comments on commit 3a9de90

Please sign in to comment.