Skip to content

Commit

Permalink
[Frontend][Torch] Check graph inputs match expected (apache#4992)
Browse files Browse the repository at this point in the history
* [Frontend][Torch] Check graph inputs match expected

* error/warn when missing/unused graph inputs

* Change to use get_graph_input_names
  • Loading branch information
jjohnson-arm authored and Trevor Morris committed Apr 16, 2020
1 parent a13f42d commit 1a6fcb6
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,20 @@ def _report_missing_conversion(op_names):
msg = "The following operators are not implemented: {}".format(missing)
raise NotImplementedError(msg)

def _check_input_names(script_module, input_shapes):
""" Check the graph inputs match the inputs """
ir_inputs = get_graph_input_names(script_module)

for ir_input in ir_inputs:
if ir_input not in input_shapes:
msg = "Missing graph input {} in input_shapes".format(ir_input)
raise RuntimeError(msg)

for input_name in input_shapes:
if input_name not in ir_inputs:
msg = "Unused graph input {} in input_shapes".format(input_name)
logging.warning(msg)


def _getattr_attr_name(node):
attribute_names = node.attributeNames()
Expand Down Expand Up @@ -1150,6 +1164,7 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
_check_input_names(script_module, input_shapes)

params = script_module.state_dict()
input_vars = parse_inputs(graph.inputs(), input_shapes)
Expand Down

0 comments on commit 1a6fcb6

Please sign in to comment.