From dbb8306a7151da002622a3dc5c8aa036c5035a11 Mon Sep 17 00:00:00 2001 From: Marisa Kirisame Date: Wed, 6 Feb 2019 18:12:27 -0800 Subject: [PATCH] address comment --- include/tvm/relay/expr.h | 14 +++++++------- include/tvm/relay/expr_functor.h | 8 ++++---- python/tvm/relay/__init__.py | 2 +- .../tvm/relay/backend/graph_runtime_codegen.py | 2 +- python/tvm/relay/expr.py | 4 ++-- python/tvm/relay/expr_functor.py | 10 +++++----- src/relay/backend/interpreter.cc | 2 +- src/relay/ir/alpha_equal.cc | 4 ++-- src/relay/ir/expr.cc | 17 +++++++++-------- src/relay/ir/expr_functor.cc | 6 +++--- src/relay/ir/hash.cc | 4 ++-- src/relay/ir/text_printer.cc | 4 ++-- src/relay/pass/fuse_ops.cc | 2 +- src/relay/pass/type_infer.cc | 4 ++-- tests/python/relay/test_backend_interpreter.py | 2 +- tests/python/relay/test_type_infer.py | 2 +- 16 files changed, 44 insertions(+), 43 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 3eb1bba8667f..b9a57c5d4618 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -435,8 +435,8 @@ class TupleGetItemNode : public ExprNode { RELAY_DEFINE_NODE_REF(TupleGetItem, TupleGetItemNode, Expr); /*! \brief Create a new Reference out of initial value. */ -class RefNew; -class RefNewNode : public ExprNode { +class RefCreate; +class RefCreateNode : public ExprNode { public: /*! \brief The initial value of the Reference. */ Expr value; @@ -447,13 +447,13 @@ class RefNewNode : public ExprNode { v->Visit("_checked_type_", &checked_type_); } - TVM_DLL static RefNew make(Expr value); + TVM_DLL static RefCreate make(Expr value); - static constexpr const char* _type_key = "relay.RefNew"; - TVM_DECLARE_NODE_TYPE_INFO(RefNewNode, ExprNode); + static constexpr const char* _type_key = "relay.RefCreate"; + TVM_DECLARE_NODE_TYPE_INFO(RefCreateNode, ExprNode); }; -RELAY_DEFINE_NODE_REF(RefNew, RefNewNode, Expr); +RELAY_DEFINE_NODE_REF(RefCreate, RefCreateNode, Expr); /*! \brief Get value out of Reference. */ class RefRead; @@ -476,7 +476,7 @@ class RefReadNode : public ExprNode { RELAY_DEFINE_NODE_REF(RefRead, RefReadNode, Expr); -/*! \brief Set value of Reference. The whole expression evaluate to an Empty Tuple. */ +/*! \brief Set value of Reference. The whole expression evaluates to an Empty Tuple. */ class RefWrite; class RefWriteNode : public ExprNode { public: diff --git a/include/tvm/relay/expr_functor.h b/include/tvm/relay/expr_functor.h index 661ba3245e64..e7b66bc1bbde 100644 --- a/include/tvm/relay/expr_functor.h +++ b/include/tvm/relay/expr_functor.h @@ -89,7 +89,7 @@ class ExprFunctor { virtual R VisitExpr_(const OpNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const TupleGetItemNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; - virtual R VisitExpr_(const RefNewNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; + virtual R VisitExpr_(const RefCreateNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefReadNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExpr_(const RefWriteNode* op, Args... args) EXPR_FUNCTOR_DEFAULT; virtual R VisitExprDefault_(const Node* op, Args...) { @@ -111,7 +111,7 @@ class ExprFunctor { RELAY_EXPR_FUNCTOR_DISPATCH(IfNode); RELAY_EXPR_FUNCTOR_DISPATCH(OpNode); RELAY_EXPR_FUNCTOR_DISPATCH(TupleGetItemNode); - RELAY_EXPR_FUNCTOR_DISPATCH(RefNewNode); + RELAY_EXPR_FUNCTOR_DISPATCH(RefCreateNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefReadNode); RELAY_EXPR_FUNCTOR_DISPATCH(RefWriteNode); return vtable; @@ -139,7 +139,7 @@ class ExprVisitor void VisitExpr_(const IfNode* op) override; void VisitExpr_(const OpNode* op) override; void VisitExpr_(const TupleGetItemNode* op) override; - void VisitExpr_(const RefNewNode* op) override; + void VisitExpr_(const RefCreateNode* op) override; void VisitExpr_(const RefReadNode* op) override; void VisitExpr_(const RefWriteNode* op) override; virtual void VisitType(const Type& t); @@ -177,7 +177,7 @@ class ExprMutator Expr VisitExpr_(const LetNode* op) override; Expr VisitExpr_(const IfNode* op) override; Expr VisitExpr_(const TupleGetItemNode* op) override; - Expr VisitExpr_(const RefNewNode* op) override; + Expr VisitExpr_(const RefCreateNode* op) override; Expr VisitExpr_(const RefReadNode* op) override; Expr VisitExpr_(const RefWriteNode* op) override; /*! diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 96aca22e621a..0af164bc7a73 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -57,7 +57,7 @@ Let = expr.Let If = expr.If TupleGetItem = expr.TupleGetItem -RefNew = expr.RefNew +RefCreate = expr.RefCreate RefRead = expr.RefRead RefWrite = expr.RefWrite diff --git a/python/tvm/relay/backend/graph_runtime_codegen.py b/python/tvm/relay/backend/graph_runtime_codegen.py index 893cbcb36163..cc510b2290cf 100644 --- a/python/tvm/relay/backend/graph_runtime_codegen.py +++ b/python/tvm/relay/backend/graph_runtime_codegen.py @@ -283,7 +283,7 @@ def visit_call(self, call): def visit_op(self, _): raise Exception("can not compile op in non-eta expanded form") - def visit_ref_new(self, _): + def visit_ref_create(self, _): raise RuntimeError("reference not supported") def visit_ref_read(self, _): diff --git a/python/tvm/relay/expr.py b/python/tvm/relay/expr.py index 590f7ab10a9f..71b89d0b4777 100644 --- a/python/tvm/relay/expr.py +++ b/python/tvm/relay/expr.py @@ -328,7 +328,7 @@ def __init__(self, tuple_value, index): @register_relay_node -class RefNew(Expr): +class RefCreate(Expr): """Create a new reference from initial value. Parameters ---------- @@ -336,7 +336,7 @@ class RefNew(Expr): The initial value. """ def __init__(self, value): - self.__init_handle_by_constructor__(_make.RefNew, value) + self.__init_handle_by_constructor__(_make.RefCreate, value) @register_relay_node diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 2d7b84610ea4..b22a4e7562e2 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -41,8 +41,8 @@ def visit(self, expr): res = self.visit_constant(expr) elif isinstance(expr, Op): res = self.visit_op(expr) - elif isinstance(expr, RefNew): - res = self.visit_ref_new(expr) + elif isinstance(expr, RefCreate): + res = self.visit_ref_create(expr) elif isinstance(expr, RefRead): res = self.visit_ref_read(expr) elif isinstance(expr, RefWrite): @@ -87,7 +87,7 @@ def visit_op(self, _): def visit_constant(self, _): raise NotImplementedError() - def visit_ref_new(self, _): + def visit_ref_create(self, _): raise NotImplementedError() def visit_ref_write(self, _): @@ -159,8 +159,8 @@ def visit_constructor(self, con): def visit_match(self, m): return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.pattern]) - def visit_ref_new(self, r): - return RefNew(self.visit(r.value)) + def visit_ref_create(self, r): + return RefCreate(self.visit(r.value)) def visit_ref_write(self, r): return RefWrite(self.visit(r.ref), self.visit(r.value)) diff --git a/src/relay/backend/interpreter.cc b/src/relay/backend/interpreter.cc index ce5d9ac3066b..893e66b41b42 100644 --- a/src/relay/backend/interpreter.cc +++ b/src/relay/backend/interpreter.cc @@ -460,7 +460,7 @@ class Interpreter : } } - Value VisitExpr_(const RefNewNode* op) final { + Value VisitExpr_(const RefCreateNode* op) final { return RefValueNode::make(Eval(op->value)); } diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index e059ad32254f..d0cc004994d4 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -369,8 +369,8 @@ class AlphaEqualHandler: } } - bool VisitExpr_(const RefNewNode* op, const Expr& e2) final { - if (const RefNewNode* nr = e2.as()) { + bool VisitExpr_(const RefCreateNode* op, const Expr& e2) final { + if (const RefCreateNode* nr = e2.as()) { return ExprEqual(op->value, nr->value); } else { return false; diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index b055b3eb90ad..bc6eee3ebc03 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -271,19 +271,19 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) p->stream << "TupleGetItemNode(" << node->tuple << ", " << node->index << ")"; }); -RefNew RefNewNode::make(Expr value) { - NodePtr n = make_node(); +RefCreate RefCreateNode::make(Expr value) { + NodePtr n = make_node(); n->value = std::move(value); - return RefNew(n); + return RefCreate(n); } -TVM_REGISTER_API("relay._make.RefNew").set_body([](TVMArgs args, TVMRetValue* ret) { - *ret = RefNewNode::make(args[0]); +TVM_REGISTER_API("relay._make.RefCreate").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = RefCreateNode::make(args[0]); }); TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable) -.set_dispatch([](const RefNewNode* node, tvm::IRPrinter* p) { - p->stream << "RefNewNode(" << node->value << ")"; +.set_dispatch([](const RefCreateNode* node, tvm::IRPrinter* p) { + p->stream << "RefCreateNode(" << node->value << ")"; }); RefRead RefReadNode::make(Expr ref) { @@ -309,7 +309,8 @@ RefWrite RefWriteNode::make(Expr ref, Expr value) { return RefWrite(n); } -TVM_REGISTER_API("relay._make.RefWrite").set_body([](TVMArgs args, TVMRetValue* ret) { +TVM_REGISTER_API("relay._make.RefWrite") +.set_body([](TVMArgs args, TVMRetValue* ret) { *ret = RefWriteNode::make(args[0], args[1]); }); diff --git a/src/relay/ir/expr_functor.cc b/src/relay/ir/expr_functor.cc index eb474ef11537..9bdfa00ce298 100644 --- a/src/relay/ir/expr_functor.cc +++ b/src/relay/ir/expr_functor.cc @@ -157,12 +157,12 @@ Expr ExprMutator::VisitExpr_(const TupleGetItemNode* g) { } } -Expr ExprMutator::VisitExpr_(const RefNewNode* op) { +Expr ExprMutator::VisitExpr_(const RefCreateNode* op) { Expr value = this->Mutate(op->value); if (value.same_as(op->value)) { return GetRef(op); } else { - return RefNewNode::make(value); + return RefCreateNode::make(value); } } @@ -254,7 +254,7 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) { this->VisitExpr(op->tuple); } -void ExprVisitor::ExprVisitor::VisitExpr_(const RefNewNode* op) { +void ExprVisitor::ExprVisitor::VisitExpr_(const RefCreateNode* op) { this->VisitExpr(op->value); } diff --git a/src/relay/ir/hash.cc b/src/relay/ir/hash.cc index 9fb94b8cc5dc..d984bb051e43 100644 --- a/src/relay/ir/hash.cc +++ b/src/relay/ir/hash.cc @@ -286,8 +286,8 @@ class RelayHashHandler: return hash; } - size_t VisitExpr_(const RefNewNode* rn) final { - size_t hash = std::hash()(RefNewNode::_type_key); + size_t VisitExpr_(const RefCreateNode* rn) final { + size_t hash = std::hash()(RefCreateNode::_type_key); hash = Combine(hash, ExprHash(rn->value)); return hash; } diff --git a/src/relay/ir/text_printer.cc b/src/relay/ir/text_printer.cc index 527178ce36a2..05179d584d84 100644 --- a/src/relay/ir/text_printer.cc +++ b/src/relay/ir/text_printer.cc @@ -363,11 +363,11 @@ class TextPrinter : return id; } - TextValue VisitExpr_(const RefNewNode* op) final { + TextValue VisitExpr_(const RefCreateNode* op) final { TextValue value = GetValue(op->value); TextValue id = this->AllocTempVar(); this->PrintIndent(); - stream_ << id << " = " << "RefNew(" << op->value << ")"; + stream_ << id << " = " << "RefCreate(" << op->value << ")"; this->PrintEndInst("\n"); return id; } diff --git a/src/relay/pass/fuse_ops.cc b/src/relay/pass/fuse_ops.cc index 1370578a2969..a6298ba448f3 100644 --- a/src/relay/pass/fuse_ops.cc +++ b/src/relay/pass/fuse_ops.cc @@ -278,7 +278,7 @@ class IndexedForwardGraph::Creator : private ExprVisitor { this->AddNode(op); } - void VisitExpr_(const RefNewNode* op) final { + void VisitExpr_(const RefCreateNode* op) final { this->Update(op->value, nullptr, kOpaque); ExprVisitor::VisitExpr_(op); this->AddNode(op); diff --git a/src/relay/pass/type_infer.cc b/src/relay/pass/type_infer.cc index 8da5cb7f7a5d..10ba3b127bbf 100644 --- a/src/relay/pass/type_infer.cc +++ b/src/relay/pass/type_infer.cc @@ -432,7 +432,7 @@ class TypeInferencer : private ExprFunctor { return solver_.Resolve(ret); } - Type VisitExpr_(const RefNewNode* op) final { + Type VisitExpr_(const RefCreateNode* op) final { return RefTypeNode::make(GetType(op->value)); } @@ -497,7 +497,7 @@ class TypeInferencer::Resolver : public ExprMutator { return AttachCheckedType(op); } - Expr VisitExpr_(const RefNewNode* op) final { + Expr VisitExpr_(const RefCreateNode* op) final { return AttachCheckedType(op); } diff --git a/tests/python/relay/test_backend_interpreter.py b/tests/python/relay/test_backend_interpreter.py index 37d78a372d05..801b3068eff0 100644 --- a/tests/python/relay/test_backend_interpreter.py +++ b/tests/python/relay/test_backend_interpreter.py @@ -121,7 +121,7 @@ def test_ref(): body = relay.Let(uv, relay.RefRead(i), body) body = relay.Let(u, relay.RefWrite(i, relay.const(2)), body) body = relay.Let(iv, relay.RefRead(i), body) - body = relay.Let(i, relay.RefNew(relay.const(1)), body) + body = relay.Let(i, relay.RefCreate(relay.const(1)), body) mod[three_with_ref] = relay.Function([], body) check_eval(three_with_ref, [], 3, mod=mod) diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index d5c0d978f424..eeefbc6c3051 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -134,7 +134,7 @@ def test_tuple(): def test_ref(): x = relay.var("x", "float32") y = relay.var("y", "float32") - r = relay.RefNew(x) + r = relay.RefCreate(x) st = relay.scalar_type("float32") assert relay.ir_pass.infer_type(r).checked_type == relay.RefType(st) g = relay.RefRead(r)