Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] GetItem #1861

Merged
merged 9 commits into from
Oct 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 22 additions & 2 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,6 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
Expr false_branch;

IfNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
Expand All @@ -378,6 +376,28 @@ class IfNode : public ExprNode {

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

/*! \brief Get a field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple */
Expr tuple;
/*! \brief which value to get */
int index;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("tuple", &tuple);
v->Visit("index", &index);
}

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.GetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
return vtable;
}
};
Expand All @@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);
};

Expand All @@ -153,6 +156,7 @@ class ExprMutator
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@
Call = expr.Call
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
8 changes: 8 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,12 @@ def __init__(self, cond, true_value, false_value):
self.__init_handle_by_constructor__(
_make.If, cond, true_value, false_value)

@register_relay_node
class TupleGetItem(Expr):
"""An expression that get field from tuple in Relay, see tvm/relay/expr.h for more details."""

def __init__(self, tuple_, index):
self.__init_handle_by_constructor__(
_make.TupleGetItem, tuple_, index)

debug_print = _expr._debug_print
6 changes: 4 additions & 2 deletions src/relay/ir/debug_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}

Doc VisitExpr_(const CallNode* c) final {
auto args = DocifyExprArray(c->args);
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}

Expand All @@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return DocOfStr(o->name);
}

Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}

public:
ExprDocifier(const Environment& env) : env(env), td(env) { }

Expand Down Expand Up @@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
std::cout << x << std::endl;
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
Expand Down
16 changes: 16 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->false_branch << ")";
});

TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
return TupleGetItem(n);
}

TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode* node, tvm::IRPrinter* p) {
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

} // namespace relay
} // namespace tvm
15 changes: 13 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
}
}

Type ExprMutator::VisitType(const Type& t) {
return t;
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
}
}

Type ExprMutator::VisitType(const Type& t) { return t; }

void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
}

Expand Down Expand Up @@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {

void ExprVisitor::VisitExpr_(const OpNode* op) { return; }

void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple);
}

void ExprVisitor::VisitType(const Type& t) { return; }

} // namespace relay
Expand Down
9 changes: 9 additions & 0 deletions src/relay/pass/alpha_eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false;
}
}

void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
this->VisitExpr(op->tuple, proj->tuple);
equal = equal && (op->index == proj->index);
} else {
equal = false;
}
}
};

bool AlphaEqual(const Expr& e1, const Expr& e2) {
Expand Down
9 changes: 4 additions & 5 deletions src/relay/pass/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>

namespace tvm {
Expand All @@ -21,11 +20,11 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }

#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
std::forward<Args>(args)...); \
});

template <typename R, typename... Args>
Expand Down
18 changes: 18 additions & 0 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return TupleTypeNode::make(fields);
}

Type VisitExpr_(const TupleGetItemNode* op) final {
// TODO(M.K.)
// handle case where field type is not known
Type tuple_type = GetType(op->tuple);
auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
if (!tuple_ty_node) {
LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
}
if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
}
return tuple_ty_node->fields[op->index];
}

Type VisitExpr_(const OpNode* op) final {
return op->op_type;
}
Expand Down Expand Up @@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
return AttachCheckedType(op);
}

Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
Expand Down
7 changes: 6 additions & 1 deletion tests/python/relay/test_ir_debug_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_call():

def test_let():
lv = relay.Var('x')
ty = relay.ty.TensorType((10, 20), "float32")
ty = relay.ty.TensorType((10, 20), 'float32')
arr = tvm.nd.array(10)
value = relay.Constant(arr)
let = relay.Let(lv, value, lv, ty)
Expand All @@ -90,3 +90,8 @@ def test_if():
right = relay.Var('right')
ife = relay.If(cond, left, right)
show(ife)

def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
show(g)
8 changes: 8 additions & 0 deletions tests/python/relay/test_ir_nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def test_if():
str(ife)


def test_tuple_get_item():
tup = relay.Var("tuple")
get = relay.TupleGetItem(tup, 1)
assert get.tuple == tup
assert get.index == 1
str(get)

if __name__ == "__main__":
test_bad_constructor()
test_span()
Expand All @@ -192,3 +199,4 @@ def test_if():
test_call()
test_let()
test_if()
test_tuple_get_item()
18 changes: 17 additions & 1 deletion tests/python/relay/test_ir_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from tvm.relay.ir_pass import well_formed

def test_well_formed():
x = relay.Var("x")
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
Expand All @@ -16,3 +16,19 @@ def test_well_formed():
# but we want all binder to be distinct from each other.
assert not well_formed(relay.Let(relay.Var("y"), f,
relay.Let(relay.Var("z"), f, v, ty), ty))


def test_tuple():
x = relay.Var('x')
assert well_formed(x)
v = relay.Constant(tvm.nd.array(10))
ty = None
let = relay.Let(x, v, x, ty)
assert well_formed(let)
assert well_formed(relay.Tuple([v, v]))
assert not well_formed(relay.Tuple([let, let]))


def test_tuple_get_item():
t = relay.Var('t')
assert well_formed(relay.TupleGetItem(t, 2))
8 changes: 8 additions & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,19 @@ def test_type_relation_alpha_equal():

assert bigger != diff_num_inputs

def test_tuple_get_item_alpha_equal():
x = relay.Var('x')
y = relay.Var('y')
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(y, 1))
assert not alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 2))
assert alpha_equal(relay.TupleGetItem(x, 1), relay.TupleGetItem(x, 1))

if __name__ == "__main__":
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_param_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
test_tuple_get_item_alpha_equal()
16 changes: 16 additions & 0 deletions tests/python/relay/test_pass_dead_code_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tvm.relay.ir_builder import convert, IRBuilder
from tvm.relay.op import log, add, equal, subtract


class env:
def __init__(self):
self.a = relay.Var("a")
Expand All @@ -22,20 +23,25 @@ def __init__(self):
self.two = convert(2.0)
self.three = convert(3.0)


e = env()


def test_let():
orig = relay.Let(e.x, e.y, e.z, e.tt)
assert alpha_equal(dead_code_elimination(orig), e.z)


def test_used_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), relay.Let(e.c, e.d, e.c, e.tt))


def test_chain_unused_let():
orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e, e.tt), e.tt)
assert alpha_equal(dead_code_elimination(orig), e.e)


# make sure we dont infinite loop
def test_recursion():
"""
Expand All @@ -60,18 +66,28 @@ def test_recursion():
assert alpha_equal(dead_code_elimination(orig), orig)
assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three, e.float32)), e.three)


def test_op_let():
assert alpha_equal(dead_code_elimination(add(relay.Let(e.a, e.one, e.three, e.float32), e.two)), add(e.three, e.two))


def test_if():
orig = relay.If(convert(True), e.a, e.b)
assert alpha_equal(dead_code_elimination(orig), e.a)


def test_tuple_get_item():
t = relay.Var('t')
g = relay.TupleGetItem(t, 0)
assert alpha_equal(dead_code_elimination(g), g)
assert alpha_equal(dead_code_elimination(relay.TupleGetItem(relay.Let(e.a, e.one, t, e.float32), 0)), g)


if __name__ == "__main__":
test_let()
test_used_let()
test_chain_unused_let()
test_recursion()
test_op_let()
test_if()
test_tuple_get_item()
Loading