diff --git a/nnvm/python/nnvm/to_relay.py b/nnvm/python/nnvm/to_relay.py
index 2a7b27877b686..3c75dc5a3fb98 100644
--- a/nnvm/python/nnvm/to_relay.py
+++ b/nnvm/python/nnvm/to_relay.py
@@ -480,6 +480,8 @@ def to_relay(graph, shape_dict, dtype_dict, params):
                 "nnvm.to_relay: unsupported operator: {0}".format(op_name))
 
     outputs = [relay_map[nid] for nid in output_ids]
+    # Unwrap TupleWrapper outputs (e.g. from split) into plain Tuple exprs.
+    outputs = [x if not isinstance(x, expr.TupleWrapper) else x.astuple() for x in outputs]
     if len(outputs) == 1:
         body = outputs[0]
     else:
diff --git a/nnvm/tests/python/compiler/test_to_relay.py b/nnvm/tests/python/compiler/test_to_relay.py
index 25037cfd3587e..200c13aebf23c 100644
--- a/nnvm/tests/python/compiler/test_to_relay.py
+++ b/nnvm/tests/python/compiler/test_to_relay.py
@@ -31,6 +31,24 @@ def check_model(sym, shapes, dtypes, params):
     relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
     np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())
 
+
+def test_split_concatenate():
+    # Regression test: split produces a Relay tuple that to_relay must unwrap.
+    shape = (2, 16)
+
+    tensor = nnvm.sym.Variable("tensor", shape=shape)
+
+    splits = nnvm.sym.split(tensor, indices_or_sections=2, axis=1)
+
+    concatenated = nnvm.sym.concatenate(*splits, axis=1)
+
+    shapes = {"tensor": shape}
+    dtypes = {"tensor": 'float32'}
+    params = {}
+
+    check_model(concatenated, shapes, dtypes, params)
+
+
 # def test_mlp():
 #     mlp, params = testing.mlp.get_workload(1)
 #     shapes = { "data": (10, 3, 224, 224) }
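
For context, here is a minimal standalone sketch of the unwrapping rule the first hunk adds, assuming a working TVM/Relay install; the helper name unwrap_outputs is illustrative and not part of the patch:

    from tvm import relay
    from tvm.relay import expr

    def unwrap_outputs(outputs):
        # Convert any TupleWrapper (returned by tuple-producing ops such as
        # relay.split) into a plain Tuple expr; pass other exprs through.
        return [x.astuple() if isinstance(x, expr.TupleWrapper) else x
                for x in outputs]

    x = relay.var("x", shape=(2, 16))
    parts = relay.split(x, indices_or_sections=2, axis=1)  # a TupleWrapper
    body, = unwrap_outputs([parts])                        # now a Tuple expr
    func = relay.Function([x], body)                       # valid function body

Without the unwrapping, the TupleWrapper would be passed directly as the function body and fail, since it is a Python convenience wrapper rather than a Relay expression node.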