Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
  • Loading branch information
MarisaKirisame committed Feb 7, 2019
1 parent d685e71 commit dbb8306
Show file tree
Hide file tree
Showing 16 changed files with 44 additions and 43 deletions.
14 changes: 7 additions & 7 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -435,8 +435,8 @@ class TupleGetItemNode : public ExprNode {
RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr);

/*! \brief Create a new Reference out of initial value. */
class RefNew;
class RefNewNode : public ExprNode {
class RefCreate;
class RefCreateNode : public ExprNode {
public:
/*! \brief The initial value of the Reference. */
Expr value;
Expand All @@ -447,13 +447,13 @@ class RefNewNode : public ExprNode {
v->Visit("_checked_type_", &checked_type_);
}

TVM_DLL static RefNew make(Expr value);
TVM_DLL static RefCreate make(Expr value);

static constexpr const char* _type_key = "relay.RefNew";
TVM_DECLARE_NODE_TYPE_INFO(RefNewNode, ExprNode);
static constexpr const char* _type_key = "relay.RefCreate";
TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode);
};

RELAY_DEFINE_NODE_REF(RefNew, RefNewNode, Expr);
RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr);

/*! \brief Get value out of Reference. */
class RefRead;
Expand All @@ -476,7 +476,7 @@ class RefReadNode : public ExprNode {

RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr);

/*! \brief Set value of Reference. The whole expression evaluate to an Empty Tuple. */
/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */
class RefWrite;
class RefWriteNode : public ExprNode {
public:
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
virtual R VisitExpr_(const OpNode* op,
Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefNewNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT;
virtual R VisitExprDefault_(const Node* op, Args...) {
Expand All @@ -111,7 +111,7 @@ class ExprFunctor<R(const Expr& n, Args...)> {
RELAY_EXPR_FUNCTOR_DISPATCH(IfNode);
RELAY_EXPR_FUNCTOR_DISPATCH(OpNode);
RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefNewNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode);
RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode);
return vtable;
Expand Down Expand Up @@ -139,7 +139,7 @@ class ExprVisitor
void VisitExpr_(const IfNode* op) override;
void VisitExpr_(const OpNode* op) override;
void VisitExpr_(const TupleGetItemNode* op) override;
void VisitExpr_(const RefNewNode* op) override;
void VisitExpr_(const RefCreateNode* op) override;
void VisitExpr_(const RefReadNode* op) override;
void VisitExpr_(const RefWriteNode* op) override;
virtual void VisitType(const Type& t);
Expand Down Expand Up @@ -177,7 +177,7 @@ class ExprMutator
Expr VisitExpr_(const LetNode* op) override;
Expr VisitExpr_(const IfNode* op) override;
Expr VisitExpr_(const TupleGetItemNode* op) override;
Expr VisitExpr_(const RefNewNode* op) override;
Expr VisitExpr_(const RefCreateNode* op) override;
Expr VisitExpr_(const RefReadNode* op) override;
Expr VisitExpr_(const RefWriteNode* op) override;
/*!
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
Let = expr.Let
If = expr.If
TupleGetItem = expr.TupleGetItem
RefNew = expr.RefNew
RefCreate = expr.RefCreate
RefRead = expr.RefRead
RefWrite = expr.RefWrite

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def visit_call(self, call):
def visit_op(self, _):
raise Exception("can not compile op in non-eta expanded form")

def visit_ref_new(self, _):
def visit_ref_create(self, _):
raise RuntimeError("reference not supported")

def visit_ref_read(self, _):
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,15 +328,15 @@ def __init__(self, tuple_value, index):


@register_relay_node
class RefNew(Expr):
class RefCreate(Expr):
"""Create a new reference from initial value.
Parameters
----------
value: tvm.relay.Expr
The initial value.
"""
def __init__(self, value):
self.__init_handle_by_constructor__(_make.RefNew, value)
self.__init_handle_by_constructor__(_make.RefCreate, value)


@register_relay_node
Expand Down
10 changes: 5 additions & 5 deletions python/tvm/relay/expr_functor.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def visit(self, expr):
res = self.visit_constant(expr)
elif isinstance(expr, Op):
res = self.visit_op(expr)
elif isinstance(expr, RefNew):
res = self.visit_ref_new(expr)
elif isinstance(expr, RefCreate):
res = self.visit_ref_create(expr)
elif isinstance(expr, RefRead):
res = self.visit_ref_read(expr)
elif isinstance(expr, RefWrite):
Expand Down Expand Up @@ -87,7 +87,7 @@ def visit_op(self, _):
def visit_constant(self, _):
raise NotImplementedError()

def visit_ref_new(self, _):
def visit_ref_create(self, _):
raise NotImplementedError()

def visit_ref_write(self, _):
Expand Down Expand Up @@ -159,8 +159,8 @@ def visit_constructor(self, con):
def visit_match(self, m):
return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern])

def visit_ref_new(self, r):
return RefNew(self.visit(r.value))
def visit_ref_create(self, r):
return RefCreate(self.visit(r.value))

def visit_ref_write(self, r):
return RefWrite(self.visit(r.ref), self.visit(r.value))
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ class Interpreter :
}
}

Value VisitExpr_(const RefNewNode* op) final {
Value VisitExpr_(const RefCreateNode* op) final {
return RefValueNode::make(Eval(op->value));
}

Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,8 +369,8 @@ class AlphaEqualHandler:
}
}

bool VisitExpr_(const RefNewNode* op, const Expr& e2) final {
if (const RefNewNode* nr = e2.as<RefNewNode>()) {
bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final {
if (const RefCreateNode* nr = e2.as<RefCreateNode>()) {
return ExprEqual(op->value, nr->value);
} else {
return false;
Expand Down
17 changes: 9 additions & 8 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -271,19 +271,19 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")";
});

RefNew RefNewNode::make(Expr value) {
NodePtr<RefNewNode> n = make_node<RefNewNode>();
RefCreate RefCreateNode::make(Expr value) {
NodePtr<RefCreateNode> n = make_node<RefCreateNode>();
n->value = std::move(value);
return RefNew(n);
return RefCreate(n);
}

TVM_REGISTER_API("relay._make.RefNew").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefNewNode::make(args[0]);
TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefCreateNode::make(args[0]);
});

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<RefNewNode>([](const RefNewNode* node, tvm::IRPrinter* p) {
p->stream << "RefNewNode(" << node->value << ")";
.set_dispatch<RefCreateNode>([](const RefCreateNode* node, tvm::IRPrinter* p) {
p->stream << "RefCreateNode(" << node->value << ")";
});

RefRead RefReadNode::make(Expr ref) {
Expand All @@ -309,7 +309,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) {
return RefWrite(n);
}

TVM_REGISTER_API("relay._make.RefWrite").set_body([](TVMArgs args, TVMRetValue* ret) {
TVM_REGISTER_API("relay._make.RefWrite")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = RefWriteNode::make(args[0], args[1]);
});

Expand Down
6 changes: 3 additions & 3 deletions src/relay/ir/expr_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,12 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) {
}
}

Expr ExprMutator::VisitExpr_(const RefNewNode* op) {
Expr ExprMutator::VisitExpr_(const RefCreateNode* op) {
Expr value = this->Mutate(op->value);
if (value.same_as(op->value)) {
return GetRef<Expr>(op);
} else {
return RefNewNode::make(value);
return RefCreateNode::make(value);
}
}

Expand Down Expand Up @@ -254,7 +254,7 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {
this->VisitExpr(op->tuple);
}

void ExprVisitor::ExprVisitor::VisitExpr_(const RefNewNode* op) {
void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) {
this->VisitExpr(op->value);
}

Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -286,8 +286,8 @@ class RelayHashHandler:
return hash;
}

size_t VisitExpr_(const RefNewNode* rn) final {
size_t hash = std::hash<std::string>()(RefNewNode::_type_key);
size_t VisitExpr_(const RefCreateNode* rn) final {
size_t hash = std::hash<std::string>()(RefCreateNode::_type_key);
hash = Combine(hash, ExprHash(rn->value));
return hash;
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/ir/text_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,11 @@ class TextPrinter :
return id;
}

TextValue VisitExpr_(const RefNewNode* op) final {
TextValue VisitExpr_(const RefCreateNode* op) final {
TextValue value = GetValue(op->value);
TextValue id = this->AllocTempVar();
this->PrintIndent();
stream_ << id << " = " << "RefNew(" << op->value << ")";
stream_ << id << " = " << "RefCreate(" << op->value << ")";
this->PrintEndInst("\n");
return id;
}
Expand Down
2 changes: 1 addition & 1 deletion src/relay/pass/fuse_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor {
this->AddNode(op);
}

void VisitExpr_(const RefNewNode* op) final {
void VisitExpr_(const RefCreateNode* op) final {
this->Update(op->value, nullptr, kOpaque);
ExprVisitor::VisitExpr_(op);
this->AddNode(op);
Expand Down
4 changes: 2 additions & 2 deletions src/relay/pass/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)> {
return solver_.Resolve(ret);
}

Type VisitExpr_(const RefNewNode* op) final {
Type VisitExpr_(const RefCreateNode* op) final {
return RefTypeNode::make(GetType(op->value));
}

Expand Down Expand Up @@ -497,7 +497,7 @@ class TypeInferencer::Resolver : public ExprMutator {
return AttachCheckedType(op);
}

Expr VisitExpr_(const RefNewNode* op) final {
Expr VisitExpr_(const RefCreateNode* op) final {
return AttachCheckedType(op);
}

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_backend_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_ref():
body = relay.Let(uv, relay.RefRead(i), body)
body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body)
body = relay.Let(iv, relay.RefRead(i), body)
body = relay.Let(i, relay.RefNew(relay.const(1)), body)
body = relay.Let(i, relay.RefCreate(relay.const(1)), body)
mod[three_with_ref] = relay.Function([], body)
check_eval(three_with_ref, [], 3, mod=mod)

Expand Down
2 changes: 1 addition & 1 deletion tests/python/relay/test_type_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def test_tuple():
def test_ref():
x = relay.var("x", "float32")
y = relay.var("y", "float32")
r = relay.RefNew(x)
r = relay.RefCreate(x)
st = relay.scalar_type("float32")
assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st)
g = relay.RefRead(r)
Expand Down

0 comments on commit dbb8306

Please sign in to comment.