Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] GetItem #1861

Merged
merged 9 commits into from
Oct 10, 2018
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,8 +360,6 @@ class IfNode : public ExprNode {
/*! \brief The expression evaluated when condition is false */
Expr false_branch;

IfNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("cond", &cond);
v->Visit("true_branch", &true_branch);
Expand All @@ -378,6 +376,23 @@ class IfNode : public ExprNode {

RELAY_DEFINE_NODE_REF(If, IfNode, Expr);

/*! \brief Get a field out of a tuple. */
class TupleGetItem;
class TupleGetItemNode : public ExprNode {
public:
/*! \brief The tuple */
Expr tuple;
/*! \brief which value to get */
int index;

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.GetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Print a debug representation of the expression to the stream.
* \param env The environment.
* \param e The expression
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +109,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(LetNode);
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
return vtable;
}
};
Expand All @@ -131,6 +133,7 @@ class ExprVisitor : public ::tvm::relay::ExprFunctor<void(const Expr& n)> {
void VisitExpr_(const LetNode* op) override;
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
virtual void VisitType(const Type& t);
};

Expand All @@ -153,6 +156,7 @@ class ExprMutator
Expr VisitExpr_(const CallNode* call_node) override;
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
/*! \brief Used to visit the types inside of expressions.
*
* Can be overloaded to transform the types in arbitrary
Expand Down
6 changes: 4 additions & 2 deletions src/relay/ir/debug_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
}

Doc VisitExpr_(const CallNode* c) final {
auto args = DocifyExprArray(c->args);
return Docify(c->op) + Seq("<", DocifyExprArray(c->args), ">");
}

Expand All @@ -244,6 +243,10 @@ class ExprDocifier : private ExprFunctor<Doc(const Expr& n)> {
return DocOfStr(o->name);
}

Doc VisitExpr_(const TupleGetItemNode* g) final {
return Docify(g->tuple) + DocOfStr(std::string(".") + std::to_string(g->index));
}

public:
ExprDocifier(const Environment& env) : env(env), td(env) { }

Expand Down Expand Up @@ -291,7 +294,6 @@ std::string PrintType(const Environment& env, const Type& t) {
TVM_REGISTER_API("relay._expr._debug_print")
.set_body([](TVMArgs args, TVMRetValue* ret) {
NodeRef x = args[1];
std::cout << x << std::endl;
if (x.as<TypeNode>()) {
*ret = PrintType(args[0], Downcast<Type>(x));
} else {
Expand Down
16 changes: 16 additions & 0 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,21 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
<< ", " << node->false_branch << ")";
});

TupleGetItem TupleGetItemNode::make(Expr tuple, int index) {
NodePtr<TupleGetItemNode> n = make_node<TupleGetItemNode>();
n->tuple = std::move(tuple);
n->index = index;
return TupleGetItem(n);
}

TVM_REGISTER_API("relay._make.TupleGetItem").set_body([](TVMArgs args, TVMRetValue * ret) {
*ret = TupleGetItemNode::make(args[0], args[1]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<TupleGetItemNode>([](const TupleGetItemNode *node, tvm::IRPrinter *p) {
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

} // namespace relay
} // namespace tvm
15 changes: 13 additions & 2 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,17 @@ Expr ExprMutator::VisitExpr_(const IfNode* op) {
}
}

Type ExprMutator::VisitType(const Type& t) {
return t;
Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
auto t = this->Mutate(g->tuple);
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
}
}

Type ExprMutator::VisitType(const Type& t) { return t; }

void ExprVisitor::ExprVisitor::VisitExpr_(const VarNode* op) {
}

Expand Down Expand Up @@ -206,6 +213,10 @@ void ExprVisitor::VisitExpr_(const IfNode* op) {

void ExprVisitor::VisitExpr_(const OpNode* op) { return; }

void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple);
}

void ExprVisitor::VisitType(const Type& t) { return; }

} // namespace relay
Expand Down
9 changes: 9 additions & 0 deletions src/relay/pass/alpha_eq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,15 @@ struct AlphaEq : ExprFunctor<void(const Expr&, const Expr&)> {
equal = false;
}
}

void VisitExpr_(const TupleGetItemNode* op, const Expr& e2) final {
if (const TupleGetItemNode* proj = e2.as<TupleGetItemNode>()) {
this->VisitExpr(op->tuple, proj->tuple);
equal = equal && (op->index == proj->index);
} else {
equal = false;
}
}
};

bool AlphaEqual(const Expr& e1, const Expr& e2) {
Expand Down
9 changes: 4 additions & 5 deletions src/relay/pass/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include <tvm/node/ir_functor.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/error.h>
#include <string>

namespace tvm {
Expand All @@ -21,11 +20,11 @@ class TypeFunctor;
#define TYPE_FUNCTOR_DEFAULT \
{ return VisitTypeDefault_(op, std::forward<Args>(args)...); }

#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
#define RELAY_TYPE_FUNCTOR_DISPATCH(OP) \
vtable.template set_dispatch<OP>( \
[](const NodeRef& n, TSelf* self, Args... args) { \
return self->VisitType_(static_cast<const OP*>(n.node_.get()), \
std::forward<Args>(args)...); \
std::forward<Args>(args)...); \
});

template <typename R, typename... Args>
Expand Down
18 changes: 18 additions & 0 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,20 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return TupleTypeNode::make(fields);
}

Type VisitExpr_(const TupleGetItemNode* op) final {
// TODO(M.K.)
// handle case where field type is not known
Type tuple_type = GetType(op->tuple);
auto tuple_ty_node = tuple_type.as<TupleTypeNode>();
if (!tuple_ty_node) {
LOG(FATAL) << "only expressions with tuple types is accepted" << GetRef<TupleGetItem>(op);
}
if (static_cast<int>(tuple_ty_node->fields.size()) <= op->index) {
LOG(FATAL) << "tuple not big enough" << GetRef<TupleGetItem>(op);
}
return tuple_ty_node->fields[op->index];
}

Type VisitExpr_(const OpNode* op) final {
return op->op_type;
}
Expand Down Expand Up @@ -293,6 +307,10 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
return AttachCheckedType(op);
}

Expr VisitExpr_(const ParamNode* op) final {
return ExprMutator::VisitExpr_(op);
}
Expand Down