Skip to content

Commit

Permalink
[Frontend][Torch] Fix up graph input handling (#5204)
Browse files Browse the repository at this point in the history
* [Frontend][Torch] Simplify operator input handling

* [Frontend][Torch] Allow user supplied input names to override graph inputs

* Fix pylint issues

* Updates from code review feedback

* Fix tutorial to use shape list input

* Disable intermittent test failure in topi vision test
  • Loading branch information
jjohnson-arm authored Apr 2, 2020
1 parent 15b1751 commit 03cbf78
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 94 deletions.
155 changes: 80 additions & 75 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1071,16 +1071,8 @@ def _get_input_names(node_or_graph):
return [inp.debugName() for inp in node_or_graph.inputs()]


def _get_op_inputs(op_node, outputs, output_index_map):
input_names = [output_index_map[name]
for name in _get_input_names(op_node)]
return [outputs[name] for name in input_names]


def _update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):
for output_name, output in name_output_pairs:
output_index_map[output_name] = len(outputs)
outputs.append(output)
def _get_op_inputs(op_node, outputs):
return [outputs[name] for name in _get_input_names(op_node)]


def _report_missing_conversion(op_names):
Expand All @@ -1100,18 +1092,31 @@ def _report_missing_conversion(op_names):
raise NotImplementedError(msg)


def _check_input_names(script_module, input_shapes):
""" Check the graph inputs match the inputs """
ir_inputs = get_graph_input_names(script_module)

for ir_input in ir_inputs:
if ir_input not in input_shapes:
msg = "Missing graph input {} in input_shapes".format(ir_input)
raise RuntimeError(msg)

for input_name in input_shapes:
if input_name not in ir_inputs:
msg = "Unused graph input {} in input_shapes".format(input_name)
def _check_inputs(graph, input_shapes):
"""
Check the graph inputs match the expected number of inputs
and are in the correct format
"""
ir_inputs = _get_graph_input_names(graph)

if not isinstance(input_shapes, list):
msg = "Graph inputs input_shapes should be list"
raise RuntimeError(msg)
missing_inputs = len(ir_inputs) - len(input_shapes)
if missing_inputs > 0:
msg = "Missing {} graph input(s) in input_shapes".format(missing_inputs)
raise RuntimeError(msg)

for num, inp in enumerate(input_shapes):
if num < len(ir_inputs):
if not isinstance(inp, tuple):
msg = "Graph input {} is not a tuple".format(num)
raise RuntimeError(msg)
if (len(inp) != 2 or not isinstance(inp[0], str)):
msg = "Graph input {} is not valid, expected ('name', shape)".format(inp)
raise RuntimeError(msg)
else:
msg = "Unused graph input {} in input_shapes".format(inp)
logging.warning(msg)


Expand Down Expand Up @@ -1203,10 +1208,19 @@ def _get_operator_nodes(nodes):
return ops


def _get_relay_input_vars(input_shapes):
""" Return Relay vars from input shapes """
return {iname: _expr.var(iname, shape=ishape)
for iname, ishape in input_shapes.items()}
def _get_relay_input_vars(graph, input_shapes):
"""
Return Relay vars from input shapes and create entries based on
expected graph inputs - to allow translation
"""
input_vars = {}
ir_inputs = _get_graph_input_names(graph)
for ir_input, (name, shape) in zip(ir_inputs, input_shapes):
inp = _expr.var(name, shape=shape)
# Translate from graph input to user input name
input_vars[ir_input] = inp

return input_vars


def get_use_chains(root_node, terminate=lambda _: False):
Expand Down Expand Up @@ -1284,33 +1298,33 @@ def convert_params(graph, state_dict):
return params, param_tensors, packed_param_map


def convert_block(block, outputs, output_index_map):
def convert_block(block, outputs):
""" Translate Torch "Block", used for prim::If and prim::Loop """
ops = _get_operator_nodes(block.nodes())
ret_names = _get_input_names(block.returnNode())
return convert_operators(ops, outputs, output_index_map, ret_names)
return convert_operators(ops, outputs, ret_names)


def convert_if(if_node, outputs, output_index_map):
def convert_if(if_node, outputs):
""" Translate Torch prim::If to Relay If """
cond = outputs[output_index_map[if_node.inputsAt(0).debugName()]]
cond = outputs[if_node.inputsAt(0).debugName()]
blocks = list(if_node.blocks())
true_branch = convert_block(blocks[0], outputs, output_index_map)
false_branch = convert_block(blocks[1], outputs, output_index_map)
true_branch = convert_block(blocks[0], outputs)
false_branch = convert_block(blocks[1], outputs)
assert len(true_branch) == 1 and len(false_branch) == 1
return _expr.If(cond, true_branch[0], false_branch[0])


def convert_loop(loop_node, outputs, output_index_map):
def convert_loop(loop_node, outputs):
""" Translate Torch prim::Loop to Relay while_loop """
def get_input(index):
ivalue = loop_node.inputsAt(index)
inode = ivalue.node()
if inode.kind() == "prim::Constant":
return _expr.const(_get_constant(inode))
var_name = ivalue.debugName()
assert var_name in output_index_map
return _wrap_const(outputs[output_index_map[var_name]])
assert var_name in outputs
return _wrap_const(outputs[var_name])

# Refer to the spec for prim::Loop below
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/OVERVIEW.md#loops
Expand Down Expand Up @@ -1342,9 +1356,9 @@ def body(*current_vals):
# Update loop variables using the prev iteration outputs
assert len(current_vals) == len(block_input_names)
for (i, iname) in enumerate(block_input_names):
outputs[output_index_map[iname]] = current_vals[i]
outputs[iname] = current_vals[i]

block_outputs = convert_block(body_block, outputs, output_index_map)
block_outputs = convert_block(body_block, outputs)

if not is_while_loop:
# iter var increment implicit in torch, so do it manually
Expand Down Expand Up @@ -1374,7 +1388,7 @@ def get_var(name, val):

name_val_pairs = list(zip(block_input_names,
[init_loop_iter_val] + init_vals))
_update_outputs_from_pairs(name_val_pairs, outputs, output_index_map)
outputs.update(name_val_pairs)

loop_iter_var = _expr.var(block_input_names[0], shape=(),
dtype=loop_iter_dtype)
Expand All @@ -1386,36 +1400,30 @@ def get_var(name, val):
return [_expr.TupleGetItem(loop_val, i+1) for i in range(num_loop_var)]


def convert_operators(operators, outputs, output_index_map, ret_names):
def convert_operators(operators, outputs, ret_names):
""" Convert each Torch IR operators to Relay equivalent """
for node_name, op_node in operators:
operator = op_node.kind()
inputs = _get_op_inputs(op_node, outputs, output_index_map)
inputs = _get_op_inputs(op_node, outputs)

if operator == "prim::Constant":
output_index_map[node_name] = len(outputs)
outputs.append(_get_constant(op_node))
outputs[node_name] = _get_constant(op_node)
elif operator == 'prim::ListConstruct' and _is_int_seq(inputs):
output_index_map[node_name] = len(outputs)
outputs.append(_expr.var(node_name, shape=inputs))
outputs[node_name] = _expr.var(node_name, shape=inputs)
elif operator in ['prim::ListConstruct', 'prim::TupleConstruct']:
output_index_map[node_name] = len(outputs)
outputs.append(inputs)
outputs[node_name] = inputs
elif operator in ["prim::ListUnpack", 'prim::TupleUnpack']:
assert len(inputs) == 1
unpacked_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(unpacked_names, inputs[0]),
outputs, output_index_map)
outputs.update(zip(unpacked_names, inputs[0]))
elif operator == "prim::If":
if_out = convert_if(op_node, outputs, output_index_map)
output_index_map[node_name] = len(outputs)
outputs.append(if_out)
if_out = convert_if(op_node, outputs)
outputs[node_name] = if_out
elif operator == "prim::Loop":
loop_out = convert_loop(op_node, outputs, output_index_map)
loop_out = convert_loop(op_node, outputs)
unpacked_names = _get_output_names(op_node)
assert len(loop_out) == len(unpacked_names)
_update_outputs_from_pairs(zip(unpacked_names, loop_out),
outputs, output_index_map)
outputs.update(zip(unpacked_names, loop_out))
else:
relay_op = _convert_map[operator]
relay_out = relay_op(inputs, _get_input_types(op_node))
Expand All @@ -1424,13 +1432,11 @@ def convert_operators(operators, outputs, output_index_map, ret_names):
# This is for torch operators that return multiple outputs
# See _adaptive_max_2d above for example
out_names = _get_output_names(op_node)
_update_outputs_from_pairs(zip(out_names, relay_out),
outputs, output_index_map)
outputs.update(zip(out_names, relay_out))
else:
output_index_map[node_name] = len(outputs)
outputs.append(relay_out)
outputs[node_name] = relay_out

return [_wrap_const(outputs[output_index_map[ret_name]])
return [_wrap_const(outputs[ret_name])
for ret_name in ret_names]


Expand All @@ -1446,11 +1452,11 @@ def get_all_op_names(graph):
return set(node.kind() for node in nodes)


def get_graph_input_names(script_module):
""" Use this function to set the keys for input_shapes"""
# It seems variable names could change the first time a copy is made
# Use the copy of the graph here to prevent troubles later
ir_inputs = _get_input_names(script_module.graph.copy())
def _get_graph_input_names(graph):
""" Get the graph input names (use after graph copy and run jit passes) """
# Variable names could change the first time a copy is made and after
# _run_jit_passes is called, expected that those functions already invoked
ir_inputs = _get_input_names(graph)
return ir_inputs[1:] # remove self at the 0th arg


Expand All @@ -1464,9 +1470,10 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):
TorchScripted PyTorch graph
Note: We currently only support traces (ie: torch.jit.trace(model, input))
input_shapes : Dictionary of input dimensions
Graph level input shape dictionary
The keys should be the same one returned by get_graph_input_names(...) above
input_shapes : List of tuples of input name and input dimensions
Graph level input shape list
The same input names need to be used for deployment, so choose easy to
remember names (such as: input0, input1)
custom_convert_map: Dictionary of str to Relay op
A custom op conversion map in the same format as _convert_map above
Expand All @@ -1487,30 +1494,28 @@ def from_pytorch(script_module, input_shapes, custom_convert_map=None):

op_names = get_all_op_names(graph)
_report_missing_conversion(op_names)
_check_input_names(script_module, input_shapes)
_check_inputs(graph, input_shapes)

params = script_module.state_dict()
input_vars = _get_relay_input_vars(input_shapes)
outputs = _get_relay_input_vars(graph, input_shapes)
param_vars, tensors, packed_param_map = convert_params(graph, params)
tvm_params = {k: tvm.nd.array(v) for k, v in tensors.items()}

input_vars.update(param_vars)
outputs = list(input_vars.values())
output_index_map = dict(zip(input_vars.keys(), range(len(outputs))))
outputs.update(param_vars)
ret_name = _get_input_names(graph.return_node())

# For quantized models
if "aten::quantize_per_tensor" in op_names:
weight_quant_params = qnn_torch.get_weight_quant_params(script_module)
qnn_torch.add_input_quant_params_to_op_inputs(graph)
qnn_torch.add_quant_params_to_outputs(outputs, output_index_map,
qnn_torch.add_quant_params_to_outputs(outputs,
packed_param_map,
weight_quant_params)
qnn_torch.add_quant_params(tvm_params, weight_quant_params)
_convert_map.update(qnn_torch.convert_map)

ret = convert_operators(_get_operator_nodes(graph.nodes()), outputs,
output_index_map, ret_name)
ret = convert_operators(_get_operator_nodes(graph.nodes()),
outputs, ret_name)

if isinstance(ret[0], list):
ret[0] = _expr.Tuple(ret[0])
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/relay/frontend/qnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,19 @@ def get_weight_quant_params(script_module):
return quant_params


def add_quant_params_to_outputs(outputs, output_index_map,
packed_param_map, quant_params):
def add_quant_params_to_outputs(outputs, packed_param_map,
quant_params):
"""
Add quant params to outputs so that they can be referenced by other
ops later. Weights are quantized here.
"""
for node_name, packed_param_name in packed_param_map.items():
qparam = quant_params[packed_param_name]
output_index_map[node_name] = len(outputs)
qweight = relay.qnn.op.quantize(qparam.weight_var, qparam.scale,
qparam.zero_point, out_dtype="int8",
axis=0)
param_tup = (qweight, qparam.scale, qparam.zero_point, qparam.bias_var)
outputs.append(param_tup)
outputs[node_name] = param_tup


def _get_quant_param_for_input(input_value):
Expand Down
7 changes: 3 additions & 4 deletions tests/python/frontend/pytorch/qnn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@

import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import get_graph_input_names
from tvm.contrib.download import download_testdata


Expand All @@ -39,7 +38,7 @@ def torch_version_check():

def get_tvm_runtime(script_module, input_name, ishape):

input_shapes = {input_name: ishape}
input_shapes = [(input_name, ishape)]
mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

with relay.build_config(opt_level=3):
Expand Down Expand Up @@ -287,7 +286,7 @@ def test_quantized_modules():
with torch.no_grad():
pt_result = script_module(inp.clone()).numpy()

input_name = get_graph_input_names(script_module)[0]
input_name = "input"
runtime = get_tvm_runtime(script_module, input_name, ishape)
runtime.set_input(input_name, inp.numpy().copy())
runtime.run()
Expand Down Expand Up @@ -383,7 +382,7 @@ def get_imagenet_input():
with torch.no_grad():
pt_result = script_module(pt_inp).numpy()

input_name = get_graph_input_names(script_module)[0]
input_name = "image"
runtime = get_tvm_runtime(script_module, input_name, (1, 3, 224, 224))
runtime.set_input(input_name, inp)
runtime.run()
Expand Down
14 changes: 7 additions & 7 deletions tests/python/frontend/pytorch/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from tvm import relay
from tvm.contrib import graph_runtime
from tvm.relay.testing.config import ctx_list
from tvm.relay.frontend.pytorch import get_graph_input_names


sys.setrecursionlimit(10000)
Expand Down Expand Up @@ -169,8 +168,8 @@ def verify_model(model_name, input_data=[],
else:
trace = trace.cpu()

input_names = get_graph_input_names(trace)
input_shapes = dict(zip(input_names,
input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
input_shapes = list(zip(input_names,
[inp.shape for inp in baseline_input]))
mod, params = relay.frontend.from_pytorch(trace, input_shapes,
custom_convert_map)
Expand Down Expand Up @@ -888,11 +887,12 @@ def test_3d_models():

def verify_script_model(pt_model, ishapes):
script_module = torch.jit.script(pt_model)
input_names = get_graph_input_names(script_module)
input_shapes = dict(zip(input_names, ishapes))

inputs = [torch.randn(input_shapes[input_name], dtype=torch.float)
for input_name in input_names]
input_names = ["i{}".format(idx) for idx, ish in enumerate(ishapes)]
input_shapes = list(zip(input_names, ishapes))

inputs = [torch.randn(shape, dtype=torch.float)
for shape in ishapes]

mod, params = relay.frontend.from_pytorch(script_module, input_shapes)

Expand Down
3 changes: 3 additions & 0 deletions topi/tests/python/test_topi_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,11 +103,14 @@ def check_device(device):
tvm.testing.assert_allclose(tvm_out1.asnumpy(), np_out1, rtol=1e-3)
tvm.testing.assert_allclose(tvm_out2.asnumpy(), np_out2, rtol=1e-3)

""" Skip this test as it is intermittent
see https://github.com/apache/incubator-tvm/pull/4901#issuecomment-595040094
for device in ['llvm', 'cuda', 'opencl']:
# Disable opencl test for now
if device != "llvm" and device != "cuda":
continue
check_device(device)
"""


def test_get_valid_counts():
Expand Down
Loading

0 comments on commit 03cbf78

Please sign in to comment.