From 491875d1717bf59a991d3f573639a80b521f13dd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?=
Date: Wed, 10 Oct 2018 14:25:38 -0700
Subject: [PATCH] [Relay] GetItem (#1861)

---
 include/tvm/relay/expr.h                      | 24 +++++++++++++++++--
 include/tvm/relay/expr_functor.h              |  4 ++++
 python/tvm/relay/__init__.py                  |  1 +
 python/tvm/relay/expr.py                      |  8 +++++++
 src/relay/ir/debug_printer.cc                 |  6 +++--
 src/relay/ir/expr.cc                          | 16 +++++++++++++
 src/relay/ir/expr_functor.cc                  | 15 ++++++++++--
 src/relay/pass/alpha_eq.cc                    |  9 +++++++
 src/relay/pass/type_functor.h                 |  9 ++++---
 src/relay/pass/type_infer.cc                  | 18 ++++++++++++++
 tests/python/relay/test_ir_debug_printer.py   |  7 +++++-
 tests/python/relay/test_ir_nodes.py           |  8 +++++++
 tests/python/relay/test_ir_well_formed.py     | 18 +++++++++++++-
 tests/python/relay/test_pass_alpha_equal.py   |  8 +++++++
 .../relay/test_pass_dead_code_elimination.py  | 16 +++++++++++++
 tests/python/relay/test_pass_free_vars.py     | 11 +++++++++
 tests/python/relay/test_type_infer.py         | 12 ++++++++++
 17 files changed, 177 insertions(+), 13 deletions(-)

diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h
index 909b702bc1a1..c6e5573d9413 100644
--- a/include/tvm/relay/expr.h
+++ b/include/tvm/relay/expr.h
@@ -360,8 +360,6 @@ class IfNode : public ExprNode {
   /*! \brief The expression evaluated when condition is false */
   Expr false_branch;
 
-  IfNode() {}
-
   void VisitAttrs(tvm::AttrVisitor* v) final {
     v->Visit("cond", &cond);
     v->Visit("true_branch", &true_branch);
@@ -378,6 +376,28 @@ class IfNode : public ExprNode {
 
 RELAY_DEFINE_NODE_REF(If, IfNode, Expr);
 
+/*! \brief Get a field out of a tuple. */
+class TupleGetItem;
+class TupleGetItemNode : public ExprNode {
+ public:
+  /*! \brief The tuple */
+  Expr tuple;
+  /*! \brief which value to get */
+  int index;
+
+  void VisitAttrs(tvm::AttrVisitor* v) final {
+    v->Visit("tuple", &tuple);
+    v->Visit("index", &index);
+  }
+
+  TVM_DLL static TupleGetItem make(Expr tuple, int index);
+
+  static constexpr const char * _type_key = "relay.GetItem";
+  TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
+};
+
+RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);
+
 /*! \brief Print a debug representation of the expression to the stream.
  * \param env The environment.
  * \param e The expression
diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h
index 1da66bc95f57..be174d33b4c8 100644
--- a/include/tvm/relay/expr_functor.h
+++ b/include/tvm/relay/expr_functor.h
@@ -89,6 +89,7 @@ class ExprFunctor {
                        Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExpr_(const OpNode* op,
                        Args... args) EXPR_FUNCTOR_DEFAULT;
+  virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
   virtual R VisitExprDefault_(const Node* op, Args...) {
     throw Error(std::string("Do not have a default for ") + op->type_key());
   }
@@ -108,6 +109,7 @@ class ExprFunctor {
     RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
     RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
+    RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
     return vtable;
   }
 };
@@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor {
   void VisitExpr_(const LetNode* op) override;
   void VisitExpr_(const IfNode* op) override;
   void VisitExpr_(const OpNode* op) override;
+  void VisitExpr_(const TupleGetItemNode* op) override;
   virtual void VisitType(const Type& t);
 };
 
@@ -153,6 +156,7 @@ class ExprMutator
   Expr VisitExpr_(const CallNode* call_node) override;
   Expr VisitExpr_(const LetNode* op) override;
   Expr VisitExpr_(const IfNode* op) override;
+  Expr VisitExpr_(const TupleGetItemNode* op) override;
   /*! \brief Used to visit the types inside of expressions.
    *
    * Can be overloaded to transform the types in arbitrary
diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py
index dd48d213f700..18c02a416d6b 100644
--- a/python/tvm/relay/__init__.py
+++ b/python/tvm/relay/__init__.py
@@ -39,3 +39,4 @@
 Call = expr.Call
 Let = expr.Let
 If = expr.If
+TupleGetItem = expr.TupleGetItem
diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py
index 9b292a74eccd..05214ca095d1 100644
--- a/python/tvm/relay/expr.py
+++ b/python/tvm/relay/expr.py
@@ -125,4 +125,12 @@ def __init__(self, cond, true_value, false_value):
         self.__init_handle_by_constructor__(
             _make.If, cond, true_value, false_value)
 
+@register_relay_node
+class TupleGetItem(Expr):
+    """An expression that gets a field from a tuple in Relay; see tvm/relay/expr.h for more details."""
+
+    def __init__(self, tuple_, index):
+        self.__init_handle_by_constructor__(
+            _make.TupleGetItem, tuple_, index)
+
 debug_print = _expr._debug_print
diff --git a/src/relay/ir/debug_printer.cc b/src/relay/ir/debug_printer.cc
index e216faa0f195..90e82d3b2dd7 100644
--- a/src/relay/ir/debug_printer.cc
+++ b/src/relay/ir/debug_printer.cc
@@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor {
   }
 
   Doc VisitExpr_(const CallNode* c) final {
-    auto args = DocifyExprArray(c->args);
     return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
   }
 
@@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor {
     return DocOfStr(o->name);
   }
 
+  Doc VisitExpr_(const TupleGetItemNode* g) final {
+    return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
+  }
+
  public:
   ExprDocifier(const Environment& env) : env(env), td(env) { }
 
@@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
 TVM_REGISTER_API("relay._expr._debug_print")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     NodeRef x = args[1];
-    std::cout << x << std::endl;
    if (x.as<TypeNode>()) {
      *ret = PrintType(args[0], Downcast<Type>(x));
    } else {
diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc
index dbbb5b84fc8b..6b56cb4e844f 100644
--- a/src/relay/ir/expr.cc
+++ b/src/relay/ir/expr.cc
@@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
                     << ", " << node->false_branch << ")";
 });
 
+TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
+  NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
+  n->tuple = std::move(tuple);
+  n->index = index;
+  return TupleGetItem(n);
+}
+
+TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
+  *ret = TupleGetItemNode::make(args[0], args[1]);
+});
+
+TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
+.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
+    p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
+});
+
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc
index e3393bdb039b..792f99d699dd 100644
--- a/src/relay/ir/expr_functor.cc
+++ b/src/relay/ir/expr_functor.cc
@@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
   }
 }
 
-Type ExprMutator::VisitType(const Type& t) {
-  return t;
+Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
+  auto t = this->Mutate(g->tuple);
+  if (g->tuple == t) {
+    return GetRef<Expr>(g);
+  } else {
+    return TupleGetItemNode::make(t, g->index);
+  }
 }
 
+Type ExprMutator::VisitType(const Type& t) { return t; }
+
 void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
 }
 
@@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {
 
 void ExprVisitor::VisitExpr_(const OpNode* op) { return; }
 
+void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
+  this->VisitExpr(op->tuple);
+}
+
 void ExprVisitor::VisitType(const Type& t) { return; }
 
 }  // namespace relay
diff --git a/src/relay/pass/alpha_eq.cc b/src/relay/pass/alpha_eq.cc
index 3c4c3d78063f..0e13a598ca3a 100644
--- a/src/relay/pass/alpha_eq.cc
+++ b/src/relay/pass/alpha_eq.cc
@@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor {
       equal = false;
     }
   }
+
+  void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
+    if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
+      this->VisitExpr(op->tuple, proj->tuple);
+      equal = equal && (op->index == proj->index);
+    } else {
+      equal = false;
+    }
+  }
 };
 
 bool AlphaEqual(const Expr& e1, const Expr& e2) {
diff --git a/src/relay/pass/type_functor.h b/src/relay/pass/type_functor.h
index a451fbe16984..70a2d9347eab 100644
--- a/src/relay/pass/type_functor.h
+++ b/src/relay/pass/type_functor.h
@@ -8,7 +8,6 @@
 
 #include
 #include
-#include
 #include
 
 namespace tvm {
@@ -21,11 +20,11 @@ class TypeFunctor;
 #define TYPE_FUNCTOR_DEFAULT \
   { return VisitTypeDefault_(op, std::forward<Args>(args)...); }
 
-#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                  \
-  vtable.template set_dispatch<OP>(                                      \
-    [](const NodeRef& n, TSelf* self, Args... args) {                    \
+#define RELAY_TYPE_FUNCTOR_DISPATCH(OP)                                    \
+  vtable.template set_dispatch<OP>(                                        \
+    [](const NodeRef& n, TSelf* self, Args... args) {                      \
       return self->VisitType_(static_cast<const OP*>(n.node_.get()),       \
-                              std::forward<Args>(args)...);              \
+                              std::forward<Args>(args)...);                \
     });
 
 template <typename R, typename... Args>
diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc
index 1e2100fa902e..72bdaf69f061 100644
--- a/src/relay/pass/type_infer.cc
+++ b/src/relay/pass/type_infer.cc
@@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor {
     return TupleTypeNode::make(fields);
   }
 
+  Type VisitExpr_(const TupleGetItemNode* op) final {
+    // TODO(M.K.)
+    // handle case where field type is not known
+    Type tuple_type = GetType(op->tuple);
+    auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
+    if (!tuple_ty_node) {
+      LOG(FATAL) << "only expressions with tuple types are accepted" << GetRef<TupleGetItem>(op);
+    }
+    if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
+      LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
+    }
+    return tuple_ty_node->fields[op->index];
+  }
+
   Type VisitExpr_(const OpNode* op) final {
     return op->op_type;
   }
@@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
     return AttachCheckedType(op);
   }
 
+  Expr VisitExpr_(const TupleGetItemNode* op) final {
+    return AttachCheckedType(op);
+  }
+
   Expr VisitExpr_(const ParamNode* op) final {
     return ExprMutator::VisitExpr_(op);
   }
diff --git a/tests/python/relay/test_ir_debug_printer.py b/tests/python/relay/test_ir_debug_printer.py
index 2ea0b7575ff8..e5f9ad2e69cd 100644
--- a/tests/python/relay/test_ir_debug_printer.py
+++ b/tests/python/relay/test_ir_debug_printer.py
@@ -77,7 +77,7 @@ def test_call():
 
 def test_let():
     lv = relay.Var('x')
-    ty = relay.ty.TensorType((10, 20), "float32")
+    ty = relay.ty.TensorType((10, 20), 'float32')
     arr = tvm.nd.array(10)
     value = relay.Constant(arr)
     let = relay.Let(lv, value, lv, ty)
@@ -90,3 +90,8 @@ def test_if():
     right = relay.Var('right')
     ife = relay.If(cond, left, right)
     show(ife)
+
+def test_tuple_get_item():
+    t = relay.Var('t')
+    g = relay.TupleGetItem(t, 0)
+    show(g)
diff --git a/tests/python/relay/test_ir_nodes.py b/tests/python/relay/test_ir_nodes.py
index d3dae9b2c3f8..79883ed225e0 100644
--- a/tests/python/relay/test_ir_nodes.py
+++ b/tests/python/relay/test_ir_nodes.py
@@ -175,6 +175,13 @@ def test_if():
     str(ife)
 
 
+def test_tuple_get_item():
+    tup = relay.Var("tuple")
+    get = relay.TupleGetItem(tup, 1)
+    assert get.tuple == tup
+    assert get.index == 1
+    str(get)
+
 if __name__ == "__main__":
     test_bad_constructor()
     test_span()
@@ -192,3 +199,4 @@ def test_if():
     test_call()
     test_let()
     test_if()
+    test_tuple_get_item()
diff --git a/tests/python/relay/test_ir_well_formed.py b/tests/python/relay/test_ir_well_formed.py
index 8bdef4d0edb5..c6cb99662bb5 100644
--- a/tests/python/relay/test_ir_well_formed.py
+++ b/tests/python/relay/test_ir_well_formed.py
@@ -3,7 +3,7 @@
 from tvm.relay.ir_pass import well_formed
 
 def test_well_formed():
-    x = relay.Var("x")
+    x = relay.Var('x')
     assert well_formed(x)
     v = relay.Constant(tvm.nd.array(10))
     ty = None
@@ -16,3 +16,19 @@
 
     # but we want all binder to be distinct from each other.
     assert not well_formed(relay.Let(relay.Var("y"), f, relay.Let(relay.Var("z"), f, v, ty), ty))
+
+
+def test_tuple():
+    x = relay.Var('x')
+    assert well_formed(x)
+    v = relay.Constant(tvm.nd.array(10))
+    ty = None
+    let = relay.Let(x, v, x, ty)
+    assert well_formed(let)
+    assert well_formed(relay.Tuple([v, v]))
+    assert not well_formed(relay.Tuple([let, let]))
+
+
+def test_tuple_get_item():
+    t = relay.Var('t')
+    assert well_formed(relay.TupleGetItem(t, 2))
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index 93f8a8fbc0b3..9fa1a554a6e2 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():
 
     assert bigger != diff_num_inputs
 
 
+def test_tuple_get_item_alpha_equal():
+    x = relay.Var('x')
+    y = relay.Var('y')
+    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
+    assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
+    assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))
 if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
+    test_constant_alpha_equal()
     test_type_param_alpha_equal()
     test_func_type_alpha_equal()
     test_tuple_type_alpha_equal()
     test_type_relation_alpha_equal()
+    test_tuple_get_item_alpha_equal()
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index db73fb5c585f..ce9bda3d254f 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -4,6 +4,7 @@
 from tvm.relay.ir_builder import convert, IRBuilder
 from tvm.relay.op import log, add, equal, subtract
 
+
 class env:
     def __init__(self):
         self.a = relay.Var("a")
@@ -22,20 +23,25 @@ def __init__(self):
         self.two = convert(2.0)
         self.three = convert(3.0)
 
+
 e = env()
 
+
 def test_let():
     orig = relay.Let(e.x, e.y, e.z, e.tt)
     assert alpha_equal(dead_code_elimination(orig), e.z)
 
+
 def test_used_let():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
     assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))
 
+
 def test_chain_unused_let():
     orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
     assert alpha_equal(dead_code_elimination(orig), e.e)
 
+
 # make sure we dont infinite loop
 def test_recursion():
     """
@@ -60,14 +66,23 @@ def test_recursion():
     assert alpha_equal(dead_code_elimination(orig), orig)
     assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)
 
+
 def test_op_let():
     assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)),
                        add(e.three, e.two))
 
+
 def test_if():
     orig = relay.If(convert(True), e.a, e.b)
     assert alpha_equal(dead_code_elimination(orig), e.a)
 
+def test_tuple_get_item():
+    t = relay.Var('t')
+    g = relay.TupleGetItem(t, 0)
+    assert alpha_equal(dead_code_elimination(g), g)
+    assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)
+
+
 if __name__ == "__main__":
     test_let()
     test_used_let()
@@ -75,3 +90,4 @@ def test_if():
     test_recursion()
     test_op_let()
     test_if()
+    test_tuple_get_item()
diff --git a/tests/python/relay/test_pass_free_vars.py b/tests/python/relay/test_pass_free_vars.py
index 002646ada582..989c9f8d25db 100644
--- a/tests/python/relay/test_pass_free_vars.py
+++ b/tests/python/relay/test_pass_free_vars.py
@@ -15,6 +15,17 @@ def test_free_vars():
     f = relay.Function([relay.Param(x, ty)], ty, x)
     assert len(free_vars(f)) == 0
 
+
+def test_tuple():
+    t = relay.Var('t')
+    fv = free_vars(relay.Tuple([t, t]))
+    assert len(fv) == 1
+    assert fv[0] == t
+    fv = free_vars(relay.TupleGetItem(t, 123))
+    assert len(fv) == 1
+    assert fv[0] == t
+
+
 def test_free_type_vars():
     tp = relay.TypeParam("")
     ty = relay.TupleType([tp, relay.TensorType([], "int32")])
diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py
index 6629932921f8..77b04590df59 100644
--- a/tests/python/relay/test_type_infer.py
+++ b/tests/python/relay/test_type_infer.py
@@ -9,6 +9,7 @@
 from tvm.relay.env import Environment
 from tvm.relay.op import log, add, equal, subtract, concatenate
 from tvm.relay.expr import Function
+from tvm import relay
 
 def assert_has_type(expr, typ, env=Environment({})):
     checked_expr = infer_type(env, expr)
@@ -110,6 +111,16 @@ def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
     fn_ty = func_type([tensor_type(3, 2), tensor_type(2, 2)], tensor_type(5, 2))
     assert_decl_has_type(ib.env, try_concat2, fn_ty)
 
+def test_tuple():
+    ib = IRBuilder()
+    dup = ib.global_var('dup')
+    x = ib.param('x')
+    with ib.decl(dup, x):
+        ib.ret(relay.Tuple([x, x]))
+    # todo: why is this not generalized?
+    fn_ty = func_type([tensor_type()], relay.TupleType([tensor_type(), tensor_type()]))
+    assert_decl_has_type(ib.env, dup, fn_ty)
+
 if __name__ == "__main__":
     test_dual_op()
     test_recursion()
@@ -117,3 +128,4 @@ def try_concat2(x: Float(3, 2), y: Float(2, 2)) -> Float(5, 2) {
     test_decl()
     test_recursion()
     test_concat()
+    test_tuple()
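
A minimal usage sketch of the new node for reviewers, mirroring the tests added above (assumes the tvm.relay Python package is built from this branch; relay.Var, relay.TupleGetItem and the tuple/index fields all come from this patch):

    from tvm import relay

    t = relay.Var('t')            # tuple-valued expression to project from
    g = relay.TupleGetItem(t, 0)  # select field 0; the debug printer renders this as "t.0"
    assert g.tuple == t and g.index == 0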