Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Reference #2489

Merged
merged 3 commits into from
Feb 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 67 additions & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode {

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.TupleGetItem";
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Create a new Reference out of initial value. */
class RefCreate;
class RefCreateNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefCreate make(Expr value);

static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);

/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefRead make(Expr ref);

static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);

/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
/*! \brief The value to write into. */
Expr value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefWrite make(Expr ref, Expr value);

static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

/*!
* \brief Base class of the temporary expression.
*
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable;
}
};
Expand All @@ -133,6 +139,9 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t);

protected:
Expand Down Expand Up @@ -168,6 +177,9 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode {

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : ValueNode {
mutable Value value;

RefValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(Value val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);

} // namespace relay
} // namespace tvm
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode {

RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);

/*!
* \brief The type of reference values.
*/
class RefType;
/*!
* \brief Reference Type in relay.
*/
class RefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;

RefTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
}

TVM_DLL static RefType make(Type value);

static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);

class TypeReporter;

/*!
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
RefType = ty.RefType

# Expr
Expr = expr.Expr
Expand All @@ -56,15 +57,18 @@
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
RefCreate = expr.RefCreate
RefRead = expr.RefRead
RefWrite = expr.RefWrite

# helper functions
var = expr.var
const = expr.const
bind = expr.bind

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator

# Parser
fromtext = parser.fromtext
9 changes: 9 additions & 0 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@ def visit_call(self, call):
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")

def visit_ref_create(self, _):
raise RuntimeError("reference not supported")

def visit_ref_read(self, _):
raise RuntimeError("reference not supported")

def visit_ref_write(self, _):
raise RuntimeError("reference not supported")

def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __repr__(self):
def __iter__(self):
return iter(self.fields)


@register_relay_node
class Closure(Value):
"""A closure produced by the interpreter."""
Expand Down Expand Up @@ -79,6 +80,13 @@ def __str__(self):
return str(self.data)


@register_relay_node
class RefValue(Value):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)


def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,46 @@ def __init__(self, tuple_value, index):
_make.TupleGetItem, tuple_value, index)


@register_relay_node
class RefCreate(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefCreate, value)


@register_relay_node
class RefRead(Expr):
"""Get the value inside the reference.
Parameters
----------
ref: tvm.relay.Expr
The reference.
"""
def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref)


@register_relay_node
class RefWrite(Expr):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why choose empty tuple

Parameters
----------
ref: tvm.relay.Expr
The reference.
value: tvm.relay.Expr
The new value.
"""
def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)


class TempExpr(Expr):
"""Baseclass of all TempExpr.

Expand Down
18 changes: 16 additions & 2 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def visit(self, expr):
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
elif isinstance(expr, RefCreate):
res = self.visit_ref_create(expr)
elif isinstance(expr, RefRead):
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))

Expand Down Expand Up @@ -81,6 +87,14 @@ def visit_op(self, _):
def visit_constant(self, _):
raise NotImplementedError()

def visit_ref_create(self, _):
raise NotImplementedError()

def visit_ref_write(self, _):
raise NotImplementedError()

def visit_ref_read(self, _):
raise NotImplementedError()

class ExprMutator(ExprFunctor):
"""
Expand Down Expand Up @@ -145,8 +159,8 @@ def visit_constructor(self, con):
def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])

def visit_ref_new(self, r):
return RefNew(self.visit(r.value))
def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,19 @@ def __init__(self, func, args, num_inputs, attrs):
func, args, num_inputs, attrs)


@register_relay_node
class RefType(Type):
"""Reference Type in relay.

Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)


def scalar_type(dtype):
"""Creates a scalar type.

Expand Down
Loading