Skip to content

Commit

Permalink
[Relay] GetItem (apache#1861)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame authored and tqchen committed Oct 10, 2018
1 parent 34ad094 commit 491875d
Show file tree
Hide file tree
Showing 17 changed files with 177 additions and 13 deletions.
24 changes: 22 additions & 2 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,6 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
Expr false_branch;

IfNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
Expand All @@ -378,6 +376,28 @@ class IfNode : public ExprNode {

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

/*! \brief Get a field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple */
Expr tuple;
/*! \brief which value to get */
int index;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("index", &index);
}

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.GetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
return vtable;
}
};
Expand All @@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);
};

Expand All @@ -153,6 +156,7 @@ class ExprMutator
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
Call = expr.Call
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
8 changes: 8 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,12 @@ def __init__(self, cond, true_value, false_value):
self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value)

@register_relay_node
class TupleGetItem(Expr):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""

def __init__(self, tuple_, index):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_, index)

debug_print = _expr._debug_print
6 changes: 4 additions & 2 deletions src/relay/ir/debug_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}

Doc VisitExpr_(const CallNode* c) final {
auto args = DocifyExprArray(c->args);
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}

Expand All @@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return DocOfStr(o->name);
}

Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}

public:
ExprDocifier(const Environment& env) : env(env), td(env) { }

Expand Down Expand Up @@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
std::cout << x << std::endl;
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
Expand Down
16 changes: 16 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->false_branch << ")";
});

TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
return TupleGetItem(n);
}

TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

} // namespace relay
} // namespace tvm
15 changes: 13 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
}
}

Type ExprMutator::VisitType(const Type& t) {
return t;
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
}
}

Type ExprMutator::VisitType(const Type& t) { return t; }

void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
}

Expand Down Expand Up @@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {

void ExprVisitor::VisitExpr_(const OpNode* op) { return; }

void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple);
}

void ExprVisitor::VisitType(const Type& t) { return; }

} // namespace relay
Expand Down
9 changes: 9 additions & 0 deletions src/relay/pass/alpha_eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false;
}
}

void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
this->VisitExpr(op->tuple, proj->tuple);
equal = equal && (op->index == proj->index);
} else {
equal = false;
}
}
};

bool AlphaEqual(const Expr& e1, const Expr& e2) {
Expand Down
9 changes: 4 additions & 5 deletions src/relay/pass/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>

namespace tvm {
Expand All @@ -21,11 +20,11 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }

#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
std::forward<Args>(args)...); \
});

template <typename R, typename... Args>
Expand Down
18 changes: 18 additions & 0 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return TupleTypeNode::make(fields);
}

Type VisitExpr_(const TupleGetItemNode* op) final {
// TODO(M.K.)
// handle case where field type is not known
Type tuple_type = GetType(op->tuple);
auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
if (!tuple_ty_node) {
LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
}
if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
}
return tuple_ty_node->fields[op->index];
}

Type VisitExpr_(const OpNode* op) final {
return op->op_type;
}
Expand Down Expand Up @@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
return AttachCheckedType(op);
}

Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
Expand Down
7 changes: 6 additions & 1 deletion tests/python/relay/test_ir_debug_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_call():

def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), "float32")
ty = relay.ty.TensorType((10, 20), 'float32')
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty)
Expand All @@ -90,3 +90,8 @@ def test_if():
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)

def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
show(g)
8 changes: 8 additions & 0 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def test_if():
str(ife)


def test_tuple_get_item():
tup = relay.Var("tuple")
get = relay.TupleGetItem(tup, 1)
assert get.tuple == tup
assert get.index == 1
str(get)

if __name__ == "__main__":
test_bad_constructor()
test_span()
Expand All @@ -192,3 +199,4 @@ def test_if():
test_call()
test_let()
test_if()
test_tuple_get_item()
18 changes: 17 additions & 1 deletion tests/python/relay/test_ir_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tvm.relay.ir_pass import well_formed

def test_well_formed():
x = relay.Var("x")
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
Expand All @@ -16,3 +16,19 @@ def test_well_formed():
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty))


def test_tuple():
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
assert well_formed(let)
assert well_formed(relay.Tuple([v, v]))
assert not well_formed(relay.Tuple([let, let]))


def test_tuple_get_item():
t = relay.Var('t')
assert well_formed(relay.TupleGetItem(t, 2))
8 changes: 8 additions & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():

assert bigger != diff_num_inputs

def test_tuple_get_item_alpha_equal():
x = relay.Var('x')
y = relay.Var('y')
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))

if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_param_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_tuple_get_item_alpha_equal()
16 changes: 16 additions & 0 deletions tests/python/relay/test_pass_dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract


class env:
def __init__(self):
self.a = relay.Var("a")
Expand All @@ -22,20 +23,25 @@ def __init__(self):
self.two = convert(2.0)
self.three = convert(3.0)


e = env()


def test_let():
orig = relay.Let(e.x, e.y, e.z, e.tt)
assert alpha_equal(dead_code_elimination(orig), e.z)


def test_used_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))


def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), e.e)


# make sure we dont infinite loop
def test_recursion():
"""
Expand All @@ -60,18 +66,28 @@ def test_recursion():
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)


def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))


def test_if():
orig = relay.If(convert(True), e.a, e.b)
assert alpha_equal(dead_code_elimination(orig), e.a)


def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)


if __name__ == "__main__":
test_let()
test_used_let()
test_chain_unused_let()
test_recursion()
test_op_let()
test_if()
test_tuple_get_item()
Loading

0 comments on commit 491875d

Please sign in to comment.