Skip to content

Commit

Permalink
move
Browse files Browse the repository at this point in the history
fix test

fix lint

fix test

add more code

fix lint
  • Loading branch information
MarisaKirisame committed Jan 23, 2019
1 parent 0806b69 commit f5258a8
Show file tree
Hide file tree
Showing 25 changed files with 585 additions and 27 deletions.
68 changes: 67 additions & 1 deletion include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,12 +428,78 @@ class TupleGetItemNode : public ExprNode {

TVM_DLL static TupleGetItem make(Expr tuple, int index);

static constexpr const char * _type_key = "relay.TupleGetItem";
static constexpr const char* _type_key = "relay.TupleGetItem";
TVM_DECLARE_NODE_TYPE_INFO(TupleGetItemNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Create a new Reference out of initial value. */
class RefNew;
class RefNewNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefNew make(Expr value);

static constexpr const char* _type_key = "relay.RefNew";
TVM_DECLARE_NODE_TYPE_INFO(RefNewNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefNew, RefNewNode, Expr);

/*! \brief Get value out of Reference. */
class RefRead;
class RefReadNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefRead make(Expr ref);

static constexpr const char* _type_key = "relay.RefRead";
TVM_DECLARE_NODE_TYPE_INFO(RefReadNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);

/*! \brief Set value of Reference. The whole expression evaluate to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
/*! \brief The Reference Expression. */
Expr ref;
/*! \brief The value to write into. */
Expr value;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("ref", &ref);
v->Visit("value", &value);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefWrite make(Expr ref, Expr value);

static constexpr const char* _type_key = "relay.RefWrite";
TVM_DECLARE_NODE_TYPE_INFO(RefWriteNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefWrite, RefWriteNode, Expr);

/*!
* \brief Base class of the temporary expression.
*
Expand Down
12 changes: 12 additions & 0 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefNewNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
throw Error(std::string("Do not have a default for ") + op->type_key());
}
Expand All @@ -108,6 +111,9 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefNewNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable;
}
};
Expand All @@ -133,6 +139,9 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefNewNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t);

protected:
Expand Down Expand Up @@ -168,6 +177,9 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefNewNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
/*!
* \brief Used to visit the types inside of expressions.
*
Expand Down
19 changes: 19 additions & 0 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ struct TensorValueNode : ValueNode {

RELAY_DEFINE_NODE_REF(TensorValue, TensorValueNode, Value);

/*! \brief A reference value. */
class RefValue;

struct RefValueNode : ValueNode {
mutable Value value;

RefValueNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
}

TVM_DLL static RefValue make(Value val);

static constexpr const char* _type_key = "relay.RefValue";
TVM_DECLARE_NODE_TYPE_INFO(RefValueNode, ValueNode);
};

RELAY_DEFINE_NODE_REF(RefValue, RefValueNode, Value);

} // namespace relay
} // namespace tvm
Expand Down
27 changes: 27 additions & 0 deletions include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,33 @@ class TupleTypeNode : public TypeNode {

RELAY_DEFINE_NODE_REF(TupleType, TupleTypeNode, Type);

/*!
* \brief The type of reference values.
*/
class RefType;
/*!
* \brief Reference Type in relay.
*/
class RefTypeNode : public TypeNode {
public:
/*! \brief The type of value in the Reference. */
Type value;

RefTypeNode() {}

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("value", &value);
v->Visit("span", &span);
}

TVM_DLL static RefType make(Type value);

static constexpr const char* _type_key = "relay.RefType";
TVM_DECLARE_NODE_TYPE_INFO(RefTypeNode, TypeNode);
};

RELAY_DEFINE_NODE_REF(RefType, RefTypeNode, Type);

class TypeReporter;

/*!
Expand Down
12 changes: 8 additions & 4 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
TypeRelation = ty.TypeRelation
IncompleteType = ty.IncompleteType
scalar_type = ty.scalar_type
RefType = ty.RefType

# Expr
Expr = expr.Expr
Expand All @@ -55,15 +56,18 @@
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator
RefNew = expr.RefNew
RefRead = expr.RefRead
RefWrite = expr.RefWrite

# helper functions
var = expr.var
const = expr.const
bind = expr.bind

# ExprFunctor
ExprFunctor = expr_functor.ExprFunctor
ExprMutator = expr_functor.ExprMutator

# Parser
fromtext = parser.fromtext
9 changes: 9 additions & 0 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,15 @@ def visit_call(self, call):
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")

def visit_ref_new(self, _):
raise RuntimeError("reference not supported")

def visit_ref_read(self, _):
raise RuntimeError("reference not supported")

def visit_ref_write(self, _):
raise RuntimeError("reference not supported")

def _get_json(self):
"""
Convert the sequence of nodes stored by the compiler into the
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __repr__(self):
def __iter__(self):
return iter(self.fields)


@register_relay_node
class Closure(Value):
"""A closure produced by the interpreter."""
Expand Down Expand Up @@ -79,6 +80,13 @@ def __str__(self):
return str(self.data)


@register_relay_node
class RefValue(Value):
def __init__(self, value):
self.__init_handle_by_constructor__(
_make.RefValue, value)


def _arg_to_ast(arg):
if isinstance(arg, TensorValue):
return Constant(arg.data.copyto(_nd.cpu(0)))
Expand Down
40 changes: 40 additions & 0 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,46 @@ def __init__(self, tuple_value, index):
_make.TupleGetItem, tuple_value, index)


@register_relay_node
class RefNew(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefNew, value)


@register_relay_node
class RefRead(Expr):
"""Get the value inside the reference.
Parameters
----------
ref: tvm.relay.Expr
The reference.
"""
def __init__(self, ref):
self.__init_handle_by_constructor__(_make.RefRead, ref)


@register_relay_node
class RefWrite(Expr):
"""
Update the value inside the reference.
The whole expression will evaluate to an empty tuple.
Parameters
----------
ref: tvm.relay.Expr
The reference.
value: tvm.relay.Expr
The new value.
"""
def __init__(self, ref, value):
self.__init_handle_by_constructor__(_make.RefWrite, ref, value)


class TempExpr(Expr):
"""Baseclass of all TempExpr.
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ def visit(self, expr):
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
elif isinstance(expr, RefNew):
res = self.visit_ref_new(expr)
elif isinstance(expr, RefRead):
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
res = self.visit_ref_write(expr)
else:
raise Exception("warning unhandled case: {0}".format(type(expr)))

Expand Down Expand Up @@ -81,6 +87,14 @@ def visit_op(self, _):
def visit_constant(self, _):
raise NotImplementedError()

def visit_ref_new(self, _):
raise NotImplementedError()

def visit_ref_write(self, _):
raise NotImplementedError()

def visit_ref_read(self, _):
raise NotImplementedError()

class ExprMutator(ExprFunctor):
"""
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/ty.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,19 @@ def __init__(self, func, args, num_inputs, attrs):
func, args, num_inputs, attrs)


@register_relay_node
class RefType(Type):
"""Reference Type in relay.
Parameters
----------
value: Type
The value type.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefType, value)


def scalar_type(dtype):
"""Creates a scalar type.
Expand Down
42 changes: 42 additions & 0 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,23 @@ TVM_REGISTER_API("relay._make.TensorValue")
*ret = TensorValueNode::make(data);
});

RefValue RefValueNode::make(Value value) {
NodePtr<RefValueNode> n = make_node<RefValueNode>();
n->value = value;
return RefValue(n);
}

TVM_REGISTER_API("relay._make.RefValue")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefValueNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefValueNode>([](const RefValueNode* node,
tvm::IRPrinter* p) {
p->stream << "RefValueNode(" << node->value << ")";
});

/*!
* \brief A stack frame in the Relay interpreter.
*
Expand Down Expand Up @@ -432,6 +449,31 @@ class Interpreter :
}
}

Value VisitExpr_(const RefWriteNode* op) final {
Value r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
rv->value = Eval(op->value);
return TupleValueNode::make({});
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}

Value VisitExpr_(const RefNewNode* op) final {
return RefValueNode::make(Eval(op->value));
}

Value VisitExpr_(const RefReadNode* op) final {
Value r = Eval(op->ref);
if (const RefValueNode* rv = r.as<RefValueNode>()) {
return rv->value;
} else {
LOG(FATAL) << "type error, type system should have caught this";
return Value();
}
}

InterpreterState get_state(Expr e = Expr()) const {
InterpreterStateNode::Stack stack;
for (auto fr : this->stack_.frames) {
Expand Down
Loading

0 comments on commit f5258a8

Please sign in to comment.