
Commit

Refine fix to handle the case where an output is a TupleWrapper
Add a regression test guarding against the original bug.
Li Xiaoquan committed Mar 11, 2019
1 parent ae2d046 commit b6bde89
Showing 2 changed files with 18 additions and 0 deletions.
1 change: 1 addition & 0 deletions nnvm/python/nnvm/to_relay.py
@@ -480,6 +480,7 @@ def to_relay(graph, shape_dict, dtype_dict, params):
"nnvm.to_relay: unsupported operator: {0}".format(op_name))

outputs = [relay_map[nid] for nid in output_ids]
outputs = [x if not isinstance(x, expr.TupleWrapper) else x.astuple() for x in outputs]
if len(outputs) == 1:
body = outputs[0]
else:
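For context, the line added above is needed because some Relay operators (split among them) return a relay.expr.TupleWrapper, a Python-side convenience wrapper that is not itself a Relay expression, so it has to be unwrapped with astuple() before it can be used in the function body that to_relay builds. Unwrapping at the output-collection step keeps the rest of the converter unaware of the wrapper. The following minimal sketch is illustrative only and not part of this commit; it assumes a TVM installation with Relay available, and the variable names are made up:

# Illustrative sketch (not from this commit): split returns a TupleWrapper,
# which must be converted to a plain Tuple expression before it can serve as
# the body of a Relay function.
from tvm import relay
from tvm.relay import expr

x = relay.var("x", shape=(2, 16))
out = relay.split(x, indices_or_sections=2, axis=1)  # relay.expr.TupleWrapper

assert isinstance(out, expr.TupleWrapper)
body = out.astuple()                 # plain Tuple expression
func = relay.Function([x], body)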
17 changes: 17 additions & 0 deletions nnvm/tests/python/compiler/test_to_relay.py
@@ -31,6 +31,23 @@ def check_model(sym, shapes, dtypes, params):
relay_out = relay_rts.evaluate(relay_model)(*list(inputs.values()))
np.testing.assert_allclose(nnvm_out.asnumpy(), relay_out.asnumpy())


def test_split_concatenate():
shape = (2, 16)

tensor = nnvm.sym.Variable("tensor", shape=shape)

splited = nnvm.sym.split(tensor, indices_or_sections=2, axis=1)

concatenated = nnvm.sym.concatenate(*splited, axis=1)

shapes = {"tensor": shape}
dtypes = {"tensor": 'float32'}
params = {}

check_model(concatenated, shapes, dtypes, params)


# def test_mlp():
# mlp, params = testing.mlp.get_workload(1)
# shapes = { "data": (10, 3, 224, 224) }
