From 9ab290f8ca21521211a61a6a8d0ebcfbaf82cbfd Mon Sep 17 00:00:00 2001 From: Wei Chen Date: Thu, 13 Jun 2019 09:21:19 +0800 Subject: [PATCH] Support export ADT value in Python (#3299) * Support export ADT value in Python * Cache original functions * Cleanup * Cleanup --- include/tvm/relay/interpreter.h | 13 ++++-- python/tvm/relay/backend/interpreter.py | 4 +- python/tvm/relay/backend/vm.py | 1 - python/tvm/relay/prelude.py | 1 - python/tvm/relay/testing/nat.py | 12 +++--- src/relay/backend/interpreter.cc | 17 ++++---- src/relay/backend/vm/compiler.cc | 41 ++++++------------- src/relay/backend/vm/vm.cc | 23 ++++------- src/relay/pass/pass_manager.cc | 3 +- tests/python/relay/test_adt.py | 11 +++-- .../python/relay/test_backend_interpreter.py | 12 +++--- .../relay/test_pass_to_a_normal_form.py | 4 +- tests/python/relay/test_vm.py | 8 ++-- 13 files changed, 69 insertions(+), 81 deletions(-) diff --git a/include/tvm/relay/interpreter.h b/include/tvm/relay/interpreter.h index 15c96bb12822..68b7ccab99c7 100644 --- a/include/tvm/relay/interpreter.h +++ b/include/tvm/relay/interpreter.h @@ -182,17 +182,22 @@ RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value); class ConstructorValue; struct ConstructorValueNode : ValueNode { - Constructor constructor; + int tag; tvm::Array fields; + /*! \brief Optional field tracking ADT constructor. */ + Constructor constructor; + void VisitAttrs(tvm::AttrVisitor* v) final { - v->Visit("constructor", &constructor); + v->Visit("tag", &tag); v->Visit("fields", &fields); + v->Visit("constructor", &constructor); } - TVM_DLL static ConstructorValue make(Constructor constructor, - tvm::Array fields); + TVM_DLL static ConstructorValue make(int tag, + tvm::Array fields, + Constructor construtor = {}); static constexpr const char* _type_key = "relay.ConstructorValue"; TVM_DECLARE_NODE_TYPE_INFO(ConstructorValueNode, ValueNode); diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 593cf7cfbdf7..ea25b970f87f 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -73,9 +73,9 @@ class Closure(Value): @register_relay_node class ConstructorValue(Value): - def __init__(self, constructor, fields, types): + def __init__(self, tag, fields, constructor, types): self.__init_handle_by_constructor__( - _make.ConstructorValue, constructor, fields, types) + _make.ConstructorValue, tag, fields, constructor, types) @register_relay_node diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 3b9946a3958d..4cb3d611abd4 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -97,7 +97,6 @@ def _eval_vm(mod, ctx, *args): args: List[tvm.NDArray, np.ndarray] The arguments to evaluate. """ - mod = optimize(mod) args = list(args) assert isinstance(args, list) diff --git a/python/tvm/relay/prelude.py b/python/tvm/relay/prelude.py index c801e490d4cf..da75b9d00e13 100644 --- a/python/tvm/relay/prelude.py +++ b/python/tvm/relay/prelude.py @@ -491,7 +491,6 @@ def load_prelude(self): def __init__(self, mod): self.mod = mod self.load_prelude() - self.define_list_adt() self.define_list_hd() self.define_list_tl() diff --git a/python/tvm/relay/testing/nat.py b/python/tvm/relay/testing/nat.py index 4c0c87ce8a9e..a76a340f113d 100644 --- a/python/tvm/relay/testing/nat.py +++ b/python/tvm/relay/testing/nat.py @@ -151,16 +151,16 @@ def add_nat_definitions(prelude): # helper functions for working with nats -def count(n): +def count(prelude, n): """Takes a ConstructorValue corresponding to a nat ADT and converts it into a Python integer. This is an example of using an ADT value in Python. """ assert isinstance(n, ConstructorValue) - if n.constructor.name_hint == 'z': + if n.tag == prelude.z.tag: return 0 - assert n.constructor.name_hint == 's' - return 1 + count(n.fields[0]) + assert n.tag == prelude.s.tag + return 1 + count(prelude, n.fields[0]) def make_nat_value(prelude, n): @@ -168,8 +168,8 @@ def make_nat_value(prelude, n): constructs a ConstructorValue representing that value as a nat. """ if n == 0: - return ConstructorValue(prelude.z, [], []) - return ConstructorValue(prelude.s, [make_nat_value(prelude, n - 1)], []) + return ConstructorValue(prelude.z.tag, [], None, []) + return ConstructorValue(prelude.s.tag, [make_nat_value(prelude, n - 1)], None, []) def make_nat_expr(prelude, n): diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index d700c2036e21..1cc81d5174a5 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -103,11 +103,13 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "RefValueNode(" << node->value << ")"; }); -ConstructorValue ConstructorValueNode::make(Constructor constructor, - tvm::Array fields) { +ConstructorValue ConstructorValueNode::make(int tag, + tvm::Array fields, + Constructor constructor) { NodePtr n = make_node(); - n->constructor = constructor; + n->tag = tag; n->fields = fields; + n->constructor = constructor; return ConstructorValue(n); } @@ -117,7 +119,7 @@ TVM_REGISTER_API("relay._make.ConstructorValue") TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) .set_dispatch([](const ConstructorValueNode* node, tvm::IRPrinter* p) { - p->stream << "ConstructorValueNode(" << node->constructor + p->stream << "ConstructorValueNode(" << node->tag << "," << node->fields << ")"; }); @@ -448,7 +450,7 @@ class Interpreter : "fusing and lowering"; } if (auto con = call->op.as()) { - return ConstructorValueNode::make(GetRef(con), args); + return ConstructorValueNode::make(con->tag, args, GetRef(con)); } // Now we just evaluate and expect to find a closure. Value fn_val = Eval(call->op); @@ -544,9 +546,8 @@ class Interpreter : const ConstructorValueNode* cvn = v.as(); CHECK(cvn) << "need to be a constructor for match"; CHECK_NE(op->constructor->tag, -1); - CHECK_NE(cvn->constructor->tag, -1); - if (op->constructor->tag == cvn->constructor->tag) { - // todo(M.K.): should use ptr equality but it is broken + CHECK_NE(cvn->tag, -1); + if (op->constructor->tag == cvn->tag) { CHECK_EQ(op->patterns.size(), cvn->fields.size()); for (size_t i = 0; i < op->patterns.size(); ++i) { if (!VisitPattern(op->patterns[i], cvn->fields[i])) { diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 07633fc346ec..9b4ab6b8f6c8 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -80,6 +80,8 @@ struct VMCompilerContext { ConstTensorShapeMap const_tensor_shape_map; // List of lowered functions std::vector lowered_funcs; + // The functions that have been lowered. + std::unordered_map seen_funcs; }; // Compute the constant pool, i.e a mapping from Constant node to constant index. @@ -184,9 +186,6 @@ struct VMCompiler : ExprFunctor { size_t registers_num; CompileEngine engine; - /*! \brief The functions that have been lowered. */ - std::unordered_map seen_funcs; - /*! \brief Global shared meta data */ VMCompilerContext* context; @@ -260,7 +259,7 @@ struct VMCompiler : ExprFunctor { void VisitExpr_(const MatchNode* match_node) { auto match = GetRef(match_node); - LOG(FATAL) << "translation of match nodes to the VM is" + LOG(FATAL) << "translation of match nodes to the VM is " << "currently unsupported" << std::endl; } @@ -280,7 +279,8 @@ struct VMCompiler : ExprFunctor { } void VisitExpr_(const GlobalVarNode* gvar) { - LOG(FATAL) << "Global variables should only appear in the call position"; + // TODO(wweic): Support Load GlobalVar into a register + LOG(FATAL) << "Loading GlobalVar into register is not yet supported"; } void VisitExpr_(const IfNode* if_node) { @@ -405,12 +405,12 @@ struct VMCompiler : ExprFunctor { // TODO(jroesch): support lowered funcs for multiple targets CHECK_EQ(cfunc->funcs.size(), 1); auto op_index = -1; - if (seen_funcs.find(cfunc->funcs[0]) == seen_funcs.end()) { + if (this->context->seen_funcs.find(cfunc->funcs[0]) == this->context->seen_funcs.end()) { op_index = this->context->lowered_funcs.size(); this->context->lowered_funcs.push_back(cfunc->funcs[0]); - seen_funcs[cfunc->funcs[0]] = op_index; + this->context->seen_funcs[cfunc->funcs[0]] = op_index; } else { - op_index = seen_funcs[cfunc->funcs[0]]; + op_index = this->context->seen_funcs[cfunc->funcs[0]]; } Emit(Instruction::InvokePacked(op_index, arity, return_val_count, unpacked_arg_regs)); @@ -429,7 +429,6 @@ struct VMCompiler : ExprFunctor { std::vector args_registers; for (auto arg : call_node->args) { - CHECK(arg.as()) << "found: " << AsText(arg, false) << std::endl << arg; this->VisitExpr(arg); args_registers.push_back(last_register); } @@ -449,18 +448,14 @@ struct VMCompiler : ExprFunctor { auto func = this->context->module->Lookup(global); if (IsClosure(func)) { auto arity = func->params.size(); - std::vector free_var_registers; - for (size_t i = 0; i < arity; ++i) { - free_var_registers.push_back(var_register_map.at(func->params[i])); - } - Emit(Instruction::AllocClosure(it->second, arity, free_var_registers, NewRegister())); + Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); } else { Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); } } else if (auto constructor_node = op.as()) { auto constructor = GetRef(constructor_node); - auto tag = GetConstructorTag(constructor); - Emit(Instruction::AllocDatatype(tag, call_node->args.size(), args_registers, NewRegister())); + Emit(Instruction::AllocDatatype(constructor->tag, call_node->args.size(), args_registers, + NewRegister())); } else if (auto var_node = op.as()) { VisitExpr(GetRef(var_node)); Emit(Instruction::InvokeClosure(last_register, args_registers, NewRegister())); @@ -469,18 +464,6 @@ struct VMCompiler : ExprFunctor { } } - size_t GetConstructorTag(tvm::relay::Constructor constructor) { - auto it = this->context->tag_map.find(constructor); - if (it != this->context->tag_map.end()) { - return it->second; - } else { - auto tag = this->context->tag_map.size(); - this->context->tag_map[constructor] = tag; - this->context->tag_index_map[tag] = constructor; - return tag; - } - } - void VisitExpr_(const FunctionNode* func_node) { if (!func_node->IsPrimitive()) { LOG(FATAL) << "local functions should have been removed by lambda lifting:" << std::endl @@ -549,7 +532,7 @@ void PopulatePackedFuncMap(const std::vector& lowered_funcs, } VMFunction CompileFunc(VMCompilerContext* context, const GlobalVar& var, const Function& func) { - DLOG(INFO) << "CompileFunc: " << std::endl << AsText(func, false) << std::endl; + DLOG(INFO) << "CompileFunc: " << var << std::endl << AsText(func, false) << std::endl; size_t params = func->params.size(); VMCompiler compiler(context); compiler.Compile(func); diff --git a/src/relay/backend/vm/vm.cc b/src/relay/backend/vm/vm.cc index 34d067b9c68c..cf0b952005fc 100644 --- a/src/relay/backend/vm/vm.cc +++ b/src/relay/backend/vm/vm.cc @@ -63,24 +63,21 @@ Object EvaluateModule(const Module& module, const std::vector ctxs, return res; } -Value VMToValue(const relay::Module& module, const relay::Type& type, Object obj) { - CHECK(module.defined() && type.defined()); +Value VMToValue(const relay::Module& module, Object obj) { + CHECK(module.defined()); switch (obj->tag) { case ObjectTag::kTensor: { - CHECK(type.as()) << "VM internal error: return value must be a tensor"; return TensorValueNode::make(ToNDArray(obj)); } case ObjectTag::kDatatype: { - // const auto* tuple_type - // const auto& data_type = obj.AsDatatype(); + const auto& data_type = obj.AsDatatype(); - // tvm::Array fields; - // for (size_t i = 0; i < data_type->fields.size(); ++i) { - // fields.push_back(VMToValue(tag_index_map, data_type->fields[i])); - // } + tvm::Array fields; + for (size_t i = 0; i < data_type->fields.size(); ++i) { + fields.push_back(VMToValue(module, data_type->fields[i])); + } - // return ConstructorValueNode::make(tag_index_map.at(data_type->tag), fields); - LOG(FATAL) << "fix me"; + return ConstructorValueNode::make(data_type->tag, fields); } default: LOG(FATAL) << "unsupported return value of type: " << obj->tag; @@ -141,8 +138,6 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue LOG(FATAL) << "expected function or module"; } - auto return_type = module->Lookup(module->entry_func)->ret_type; - std::vector vm_args; for (auto i = 3; i < args.size(); i++) { Object obj = args[i]; @@ -151,7 +146,7 @@ TVM_REGISTER_API("relay._vm._evaluate_vm").set_body([](TVMArgs args, TVMRetValue auto result = EvaluateModule(module, {ctx}, vm_args); DLOG(INFO) << "Evaluate VM returning: result=" << result->tag; - *ret = VMToValue(module, return_type, result); + *ret = VMToValue(module, result); }); } // namespace vm diff --git a/src/relay/pass/pass_manager.cc b/src/relay/pass/pass_manager.cc index 500bdce742a0..fa79a5e82f9e 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/relay/pass/pass_manager.cc @@ -316,7 +316,8 @@ Module FunctionPassNode::operator()(const Module& mod, Module updated_mod = mod; // Execute the pass function and return a new module. std::vector > updates; - for (const auto& it : mod->functions) { + auto original = mod->functions; + for (const auto& it : original) { auto updated_func = SkipFunction(it.second) ? it.second : pass_func(it.second, updated_mod, pass_ctx); diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index 77f4ab1f16a0..f3a08a869841 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -21,12 +21,15 @@ from tvm.relay.backend.interpreter import Value, TupleValue, ConstructorValue from tvm.relay import testing, create_executor from tvm.relay.prelude import Prelude -from tvm.relay.testing import add_nat_definitions, count, make_nat_value, make_nat_expr +from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr mod = relay.Module() p = Prelude(mod) add_nat_definitions(p) +def count(e): + return count_(p, e) + ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") @@ -91,18 +94,18 @@ def to_list(l): val = l ret = [] while True: - if val.constructor.name_hint == 'cons': + if val.tag == p.cons.tag: ret.append(val.fields[0]) val = val.fields[1] else: - assert val.constructor.name_hint == 'nil' + assert val.tag == p.nil.tag break return ret def tree_to_dict(t): assert isinstance(t, ConstructorValue) ret = {} - assert t.constructor.name_hint == 'rose' + assert t.tag == p.rose.tag ret['member'] = t.fields[0] ret['children'] = [] for subtree in to_list(t.fields[1]): diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 1e5e2310e927..11ce11e48322 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -183,11 +183,11 @@ def test_function_taking_adt_ref_tuple(): prelude = relay.prelude.Prelude(mod) intrp = create_executor("debug", mod) - nil_value = ConstructorValue(prelude.nil, [], []) - cons_value = ConstructorValue(prelude.cons, [ + nil_value = ConstructorValue(prelude.nil.tag, [], prelude.nil, []) + cons_value = ConstructorValue(prelude.cons.tag, [ TensorValue(np.random.rand(1, 10).astype('float32')), nil_value - ], [relay.TensorType((1, 10), 'float32')]) + ], prelude.cons, [relay.TensorType((1, 10), 'float32')]) ref_value = RefValue(TensorValue(np.random.rand(1, 10).astype('float32'))) tuple_value = TupleValue(*[ @@ -197,16 +197,16 @@ def test_function_taking_adt_ref_tuple(): id_func = intrp.evaluate(prelude.id) res_nil = id_func(nil_value) - assert res_nil.constructor == nil_value.constructor + assert res_nil.tag == nil_value.tag assert len(res_nil.fields) == 0 res_cons = id_func(cons_value) - assert res_cons.constructor == cons_value.constructor + assert res_cons.tag == cons_value.tag assert len(res_cons.fields) == len(cons_value.fields) tvm.testing.assert_allclose(res_cons.fields[0].asnumpy(), cons_value.fields[0].asnumpy()) assert isinstance(res_cons.fields[1], ConstructorValue) - assert res_cons.fields[1].constructor == prelude.nil + assert res_cons.fields[1].tag == prelude.nil.tag assert len(res_cons.fields[1].fields) == 0 res_ref = id_func(ref_value) diff --git a/tests/python/relay/test_pass_to_a_normal_form.py b/tests/python/relay/test_pass_to_a_normal_form.py index f395580a3f84..db40c86d4b28 100644 --- a/tests/python/relay/test_pass_to_a_normal_form.py +++ b/tests/python/relay/test_pass_to_a_normal_form.py @@ -142,8 +142,8 @@ def test_nat_add(): ctx = tvm.context("llvm", 0) intrp = create_executor(mod=mod, ctx=ctx, target="llvm") assert mod[add].checked_type == relay.FuncType([nat(), nat()], nat()) - assert count(intrp.evaluate(add(s(z()), s(z())))) == 2 - assert count(intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 + assert count(p, intrp.evaluate(add(s(z()), s(z())))) == 2 + assert count(p, intrp.evaluate(to_a_normal_form(add(s(z()), s(z())), mod))) == 2 assert "let" in mod[add].astext() diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index d727e776cbcd..12e343be02ac 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -185,9 +185,7 @@ def test_tuple_second(): result = veval(f, (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), j_data) -@nottest def test_list_constructor(): - # TODO(wweic): implement pattern match to support this test def to_list(o): if isinstance(o, tvm.relay.backend.interpreter.TensorValue): return [o.data.asnumpy().tolist()] @@ -204,6 +202,11 @@ def to_list(o): cons = p.cons l = p.l + # remove all functions to not have pattern match to pass vm compilation + # TODO(wweic): remove the hack and implement pattern match + for v, _ in mod.functions.items(): + mod[v] = relay.const(0) + one2 = cons(relay.const(1), nil()) one3 = cons(relay.const(2), one2) one4 = cons(relay.const(3), one3) @@ -213,7 +216,6 @@ def to_list(o): result = veval(mod)() obj = to_list(result) - import pdb; pdb.set_trace() tvm.testing.assert_allclose(obj, np.array([3,2,1])) def test_let_tensor():