diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index fc47f4e1b7c8..593cf7cfbdf7 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -118,6 +118,8 @@ def _arg_to_ast(arg): return Constant(arg.data.copyto(nd.cpu(0))) elif isinstance(arg, TupleValue): return Tuple([_arg_to_ast(field) for field in arg.fields]) + elif isinstance(arg, tuple): + return Tuple([_arg_to_ast(field) for field in arg]) elif isinstance(arg, RefValue): return RefCreate(_arg_to_ast(arg.value)) elif isinstance(arg, ConstructorValue): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index e8a99e14d741..1e5e2310e927 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -217,6 +217,31 @@ def test_function_taking_adt_ref_tuple(): tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(), tuple_value.fields[i].asnumpy()) +def test_tuple_passing(): + x = relay.var('x', type_annotation=relay.ty.TupleType([ + relay.ty.TensorType((), 'int64'), + relay.ty.TensorType((), 'int64')])) + + fn = relay.Function([x], relay.expr.TupleGetItem(x, 0)) + mod = relay.Module({}) + gv = relay.GlobalVar('fn') + mod[gv] = fn + mod.entry_func = gv + mod[gv] = relay.ir_pass.infer_type(mod[gv], mod=mod) + + ctx = tvm.cpu() + target = tvm.target.create('llvm') + exec = relay.create_executor(mod=mod, ctx=ctx, target=target) + f = exec.evaluate(gv) + # First use a Python tuple. + out = f((10, 8)) + tvm.testing.assert_allclose(out.asnumpy(), np.array(10)) + # Second use a tuple value. + value_tuple = TupleValue( + TensorValue(np.array(11)), + TensorValue(np.array(12))) + out = f(value_tuple) + tvm.testing.assert_allclose(out.asnumpy(), np.array(11)) if __name__ == "__main__": test_id() @@ -231,4 +256,5 @@ def test_function_taking_adt_ref_tuple(): test_tensor_value() test_tuple_value() test_tuple_getitem() - test_function_taking_adt_ref_tuple() \ No newline at end of file + test_function_taking_adt_ref_tuple() + test_tuple_passing()