Skip to content

Commit

Permalink
Fix bug in ONNX importer (#3084)
Browse files Browse the repository at this point in the history
  • Loading branch information
jroesch authored and tqchen committed Apr 29, 2019
1 parent a706ad1 commit ba6f194
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
5 changes: 4 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,7 +944,10 @@ def from_onnx(self, graph, opset):
dtype=self._params[i_name].dtype)
else:
self._num_input += 1
tshape = self._shape[i_name] if i_name in self._shape else ()
if i_name in self._shape:
tshape = self._shape[i_name]
else:
raise ValueError("Must provide an input shape for `{0}`.".format(i_name))
if isinstance(self._dtype, dict):
dtype = self._dtype[i_name] if i_name in self._dtype else d_type
else:
Expand Down
9 changes: 7 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,10 +724,15 @@ def verify_constantfill(is_shape, input_dim, out_dim, value, dtype, **kwargs):
else:
fill_node = helper.make_node("ConstantFill", ["input_a"], ["out"], value=value, dtype=dtype, **kwargs)

if is_shape == True:
inputs = []
else:
inputs = [helper.make_tensor_value_info("input_a",
TensorProto.FLOAT, list(input_dim))]

graph = helper.make_graph([fill_node],
"fill_test",
inputs = [helper.make_tensor_value_info("input_a",
TensorProto.FLOAT, list(input_dim))],
inputs,
outputs = [helper.make_tensor_value_info("out",
TensorProto.FLOAT, list(out.shape))])

Expand Down

0 comments on commit ba6f194

Please sign in to comment.