diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py index 7d792116b104..da2f394cb442 100644 --- a/nnvm/python/nnvm/to_relay.py +++ b/nnvm/python/nnvm/to_relay.py @@ -1,6 +1,5 @@ # pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name, unused-argument """Convert an NNVM graph to Relay.""" -import json import numpy from tvm import relay, nd @@ -241,7 +240,7 @@ def _split(children, attrs, odtype='float32'): axis = attrs.get_int('axis', 0) - return op.split(children[0], indices_or_sections, axis).astuple() + return op.split(children[0], indices_or_sections, axis) def _squeeze(children, attrs, odtype='float32'): axis = attrs.get_int_tuple('axis', None) @@ -441,12 +440,10 @@ def to_relay(graph, shape_dict, dtype_dict, params): graph = graph.apply(["InferShape", "InferType"]) shape = graph.json_attr("shape") dtype = [graph_attr.TCODE_TO_DTYPE[di] for di in graph.json_attr("dtype")] - heads = [x[0] for x in json.loads(graph.json())['heads']] gidx = graph.index relay_map = {} fn_params = [] - output_ids = [] for nid, node in enumerate(gidx.nodes): children = [] @@ -468,9 +465,6 @@ def to_relay(graph, shape_dict, dtype_dict, params): fn_params.append(v) relay_map[nid] = v else: - if nid in heads: - output_ids.append(nid) - if op_name in NNVM_OP_2_RELAY_OP: str_attrs = StrAttrsDict(attrs) call = NNVM_OP_2_RELAY_OP[op_name](children, str_attrs, odtype) @@ -479,7 +473,14 @@ def to_relay(graph, shape_dict, dtype_dict, params): raise Exception( "nnvm.to_relay: unsupported operator: {0}".format(op_name)) - outputs = [relay_map[nid] for nid in output_ids] + outputs = [] + for nid, idx, _ in gidx.output_entries: + output = relay_map[nid] + if isinstance(output, expr.TupleWrapper): + outputs.append(output[idx]) + else: + outputs.append(output) + if len(outputs) == 1: body = outputs[0] else: diff --git a/tests/python/frontend/nnvm_to_relay/test_forward.py b/tests/python/frontend/nnvm_to_relay/test_forward.py index 23c0a977d6d9..f32fb803e49b 100644 --- a/tests/python/frontend/nnvm_to_relay/test_forward.py +++ b/tests/python/frontend/nnvm_to_relay/test_forward.py @@ -72,6 +72,23 @@ def test_forward_dqn(): verify_nnvm_to_relay(model, params, data_shape=(1, 4, 84, 84)) +def test_forward_split_concatenate(): + shape = (2, 16) + + tensor = nnvm.sym.Variable("data", shape=shape) + + splited = nnvm.sym.split(tensor, indices_or_sections=2, axis=1) + + concatenated = nnvm.sym.concatenate(*splited, axis=1) + + params = {} + + verify_nnvm_to_relay(splited[0], params, data_shape=shape) + verify_nnvm_to_relay(splited[1], params, data_shape=shape) + verify_nnvm_to_relay(splited, params, data_shape=shape) + verify_nnvm_to_relay(concatenated, params, data_shape=shape) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -80,3 +97,4 @@ def test_forward_dqn(): test_forward_inception_v3() test_forward_densenet() test_forward_dqn() + test_forward_split_concatenate()