From 4d64ff2c84f068b0e59e2d003389c7d56e531f34 Mon Sep 17 00:00:00 2001 From: "Steven S. Lyubomirsky" Date: Tue, 16 Apr 2019 13:44:30 -0700 Subject: [PATCH] Ensure interpreted functions can take values that are not TensorValues (#3015) --- python/tvm/relay/backend/interpreter.py | 8 +++- .../python/relay/test_backend_interpreter.py | 43 +++++++++++++++++++ 2 files changed, 50 insertions(+), 1 deletion(-) diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index ddcbd79122e0..bb43b278639a 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -24,7 +24,7 @@ from .. import _make, ir_pass from ... import register_func, nd from ..base import NodeBase, register_relay_node -from ..expr import Call, Constant, GlobalVar, Function, const +from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, Function, const from ..scope_builder import ScopeBuilder class Value(NodeBase): @@ -112,6 +112,12 @@ def __init__(self, value): def _arg_to_ast(arg): if isinstance(arg, TensorValue): return Constant(arg.data.copyto(nd.cpu(0))) + elif isinstance(arg, TupleValue): + return Tuple([_arg_to_ast(field) for field in arg.fields]) + elif isinstance(arg, RefValue): + return RefCreate(_arg_to_ast(arg.value)) + elif isinstance(arg, ConstructorValue): + return Call(arg.constructor, [_arg_to_ast(field) for field in arg.fields]) elif isinstance(arg, np.ndarray): return Constant(nd.array(arg)) elif isinstance(arg, Constant): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index da794e25ab56..5d8ceb4c7bdc 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -19,6 +19,7 @@ import tvm.testing from tvm import relay from tvm.relay.backend.interpreter import Value, TupleValue, TensorValue +from tvm.relay.backend.interpreter import RefValue, ConstructorValue from tvm.relay.scope_builder import ScopeBuilder from tvm.relay import testing, create_executor @@ -156,6 +157,7 @@ def test_tensor_value(): xx = np.ones((1, 10)).astype("float32") check_eval(relay.Function([x], x), [TensorValue(xx)], xx) + def test_kwargs_params(): x = relay.var("x", shape=(1, 10)) y = relay.var("y", shape=(1, 10)) @@ -170,6 +172,46 @@ def test_kwargs_params(): tvm.testing.assert_allclose(res.asnumpy(), x_data + y_data + z_data) +def test_function_taking_adt_ref_tuple(): + mod = relay.Module() + prelude = relay.prelude.Prelude(mod) + intrp = create_executor("debug", mod) + + nil_value = ConstructorValue(prelude.nil, [], []) + cons_value = ConstructorValue(prelude.cons, [ + TensorValue(np.random.rand(1, 10).astype('float32')), + nil_value + ], [relay.TensorType((1, 10), 'float32')]) + + ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32'))) + tuple_value = TupleValue(*[ + TensorValue(np.random.rand(1, 10).astype('float32')) for _ in range(10) + ]) + + id_func = intrp.evaluate(prelude.id) + + res_nil = id_func(nil_value) + assert res_nil.constructor == nil_value.constructor + assert len(res_nil.fields) == 0 + + res_cons = id_func(cons_value) + assert res_cons.constructor == cons_value.constructor + assert len(res_cons.fields) == len(cons_value.fields) + tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(), + cons_value.fields[0].asnumpy()) + assert isinstance(res_cons.fields[1], ConstructorValue) + assert res_cons.fields[1].constructor == prelude.nil + assert len(res_cons.fields[1].fields) == 0 + + res_ref = id_func(ref_value) + tvm.testing.assert_allclose(res_ref.value.asnumpy(), ref_value.value.asnumpy()) + + res_tuple = id_func(tuple_value) + for i in range(10): + tvm.testing.assert_allclose(res_tuple.fields[i].asnumpy(), + tuple_value.fields[i].asnumpy()) + + if __name__ == "__main__": test_id() test_add_const() @@ -181,3 +223,4 @@ def test_kwargs_params(): test_kwargs_params() test_ref() test_tensor_value() + test_function_taking_adt_ref_tuple()