diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index d2000263479d..93ce2dc92fbd 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None): """ return _ir_pass.eliminate_common_subexpr(expr, fskip) - -def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True): - """ - THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT! - USE `.astext()` INSTEAD! - - A version of the pretty printer intended for debugging passes. Contains - advanced printing options. - - Parameters - ---------- - ast : Union[relay.Expr, relay.Module, relay.Type] - The relay fragment to be turned into text. - - show_meta_data : bool - Whether to include meta data section in the text - if there is meta data. - - annotate: Optional[relay.Expr->str] - Optional annotate function to provide additional - information in the comment block. - - gnf : bool - Whether to print in GNF. If it is disabled, pointers are left implicit. - - Returns - ------- - text : str - A text representation of `ast`. - """ - return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf) - - def partial_evaluate(expr): """ Evaluate the static fragment of the code. diff --git a/src/relay/ir/pretty_printer.cc b/src/relay/ir/pretty_printer.cc index 969f08b32e83..f4a830040f70 100644 --- a/src/relay/ir/pretty_printer.cc +++ b/src/relay/ir/pretty_printer.cc @@ -22,12 +22,22 @@ * \file pretty_printer.cc * \brief Pretty printer for Relay programs * Supports ANF, GNF, and metadata. + * + * Inlining heuristics: + * - Always inline: + * - GlobalVar + * - Constant + * - Op + * - Var + * - Otherwise, inline if the node is at the end of a scope and is used at most once. */ + #include #include #include #include "doc.h" #include "type_functor.h" +#include "../pass/dependency_graph.h" #include "../../lang/attr_functor.h" namespace tvm { @@ -135,10 +145,8 @@ class PrettyPrinter : public TypeFunctor, public AttrFunctor { public: - explicit PrettyPrinter(bool GNF, - bool show_meta_data, + explicit PrettyPrinter(bool show_meta_data, runtime::TypedPackedFunc annotate) : - GNF_(GNF), show_meta_data_(show_meta_data), annotate_(annotate) {} @@ -150,10 +158,9 @@ class PrettyPrinter : Doc doc; // additional information in comment. if (annotate_ != nullptr) { - return doc << " // " << annotate_(expr); + return doc << " /* " << annotate_(expr) << " */"; } else if (expr->checked_type_.defined()) { - doc << " // ty="; - return doc << Print(expr->checked_type()); + return doc << " /* ty=" << Print(expr->checked_type()) << " */"; } else { return doc; } @@ -176,13 +183,18 @@ class PrettyPrinter : // print in a new scope doc_stack_.push_back(Doc()); // must print first so doc_stack_.back() reference doesn't become stale - Doc doc = Print(node); + Doc doc = Print(node, false, true); doc = doc_stack_.back() << doc; doc_stack_.pop_back(); return doc; } Doc PrintFinal(const NodeRef& node) { + if (node.as_derived()) { + Expr expr = Downcast(node); + dg_ = DependencyGraph::Create(&arena_, expr); + } + Doc doc; doc << PrintScope(node); if (!meta_.empty()) { @@ -200,9 +212,9 @@ class PrettyPrinter : Doc PrintAttrs(const Attrs& attrs, const Expr& op); - Doc Print(const NodeRef& node, bool meta = false) { + Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) { if (node.as_derived()) { - return PrintExpr(Downcast(node), meta); + return PrintExpr(Downcast(node), meta, try_inline); } else if (node.as_derived()) { return PrintType(Downcast(node), meta); } else if (node.as_derived()) { @@ -308,7 +320,12 @@ class PrettyPrinter : return val; } - inline bool IsAtomicExpr(const Expr& expr) { + bool IsUnique(const Expr& expr) { + return !(dg_.expr_node.at(expr)->parents.head && + dg_.expr_node.at(expr)->parents.head->next); + } + + bool AlwaysInline(const Expr& expr) { return expr.as() || expr.as() || expr.as() || expr.as(); } @@ -316,17 +333,25 @@ class PrettyPrinter : //------------------------------------ // Overload of Expr printing functions //------------------------------------ - Doc PrintExpr(const Expr& expr, bool meta) { + Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) { // Exploit memoization to print GNF. // The first time we visit an expression, we need to allocate a temp var // for it. Every subsequent time we can just use its assigned variable. // This works since hashing uses pointer equality. + + // determine whether to inline + bool inline_expr = AlwaysInline(expr); + if (try_inline) { + inline_expr |= IsUnique(expr); + } + auto it = memo_.find(expr); if (it != memo_.end()) return it->second; + Doc printed_expr; if (meta) { printed_expr = meta_.GetMetaNode(GetRef(expr.get())); - } else if (GNF_ && expr.as()) { + } else if (!inline_expr && expr.as()) { // wrap GNFed let in brackets Doc body; printed_expr << "{"; @@ -335,28 +360,26 @@ class PrettyPrinter : } else { printed_expr = VisitExpr(expr); } - // we choose to inline atomic exprs - if (GNF_ && !IsAtomicExpr(expr)) { - Doc temp_var = AllocTemp(); - memo_[expr] = temp_var; - doc_stack_.back() << temp_var << " = " << printed_expr; - if (expr.as()) { - doc_stack_.back() << PrintOptionalInfo(expr); - } - doc_stack_.back() << "\n"; - return temp_var; - } else if (expr.as()) { + + if (expr.as()) { + printed_expr << PrintOptionalInfo(expr); + } + + // add expr to doc + if (expr.as()) { // This is our first time visiting the var and we hit the VarNode case // in the visitor. Thus the variable is free. doc_stack_.back() << "free_var " << printed_expr << "\n"; // Memoization is done in AllocVar. return memo_[expr]; - } else { + } else if (inline_expr) { memo_[expr] = printed_expr; - if (GNF_ && expr.as()) { - printed_expr << PrintOptionalInfo(expr); - } return printed_expr; + } else { + Doc temp_var = AllocTemp(); + memo_[expr] = temp_var; + doc_stack_.back() << temp_var << " = " << printed_expr << "\n"; + return temp_var; } } @@ -420,8 +443,9 @@ class PrettyPrinter : Doc VisitExpr_(const LetNode* op) final { Doc doc; - doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n"; + doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n"; // we use a scope here so GNF hoisting doesn't escape too far + // and nested, unique lets are not hoisted doc << PrintScope(op->body); return doc; } @@ -456,6 +480,8 @@ class PrettyPrinter : Doc doc; int counter = 0; for (const auto& kv : mod->functions) { + dg_ = DependencyGraph::Create(&arena_, kv.second); + std::ostringstream os; if (counter++ != 0) { doc << "\n"; @@ -664,8 +690,6 @@ class PrettyPrinter : } private: - /*! \brief Whether to use GNF. */ - bool GNF_; /*! \brief Whether to print meta data. */ bool show_meta_data_; /*! \brief additional comment function */ @@ -682,6 +706,10 @@ class PrettyPrinter : TextMetaDataContext meta_; /*! \brief counter of temporary variable */ size_t temp_var_counter_{0}; + /*! \brief arena for dependency graph */ + common::Arena arena_; + /*! \brief dependency graph of the expr */ + DependencyGraph dg_; class AttrPrinter; friend class AttrPrinter; }; @@ -751,25 +779,17 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) { std::string PrettyPrint_(const NodeRef& node, bool show_meta_data, - runtime::TypedPackedFunc annotate, - bool gnf) { + runtime::TypedPackedFunc annotate) { Doc doc; doc << "v0.0.1" << "\n" - << PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node); + << PrettyPrinter(show_meta_data, annotate).PrintFinal(node); return doc.str(); } std::string AsText(const NodeRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate) { - return PrettyPrint_(node, show_meta_data, annotate, true); -} - -std::string PassDebugPrint(const NodeRef& node, - bool show_meta_data, - runtime::TypedPackedFunc annotate, - bool gnf) { - return PrettyPrint_(node, show_meta_data, annotate, gnf); + bool show_meta_data, + runtime::TypedPackedFunc annotate) { + return PrettyPrint_(node, show_meta_data, annotate); } TVM_REGISTER_API("relay._expr.AsText") @@ -777,11 +797,5 @@ TVM_REGISTER_API("relay._expr.AsText") bool, runtime::TypedPackedFunc)>(AsText); -TVM_REGISTER_API("relay._ir_pass.pass_debug_print") -.set_body_typed, - bool)>(PassDebugPrint); - } // namespace relay } // namespace tvm diff --git a/src/relay/pass/dependency_graph.cc b/src/relay/pass/dependency_graph.cc new file mode 100644 index 000000000000..6e25086fe826 --- /dev/null +++ b/src/relay/pass/dependency_graph.cc @@ -0,0 +1,165 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/relay/pass/dependency_graph.cc + * \brief + */ +#include "dependency_graph.h" +#include +#include +#include + +namespace tvm { +namespace relay { + +// Creator of DependencyGraph +class DependencyGraph::Creator : private ExprFunctor { + public: + explicit Creator(common::Arena* arena) + : arena_(arena) {} + + DependencyGraph Create(const Expr& body) { + this->VisitExpr(body); + return std::move(graph_); + } + + private: + /*! \brief allocator of all the internal node object */ + common::Arena* arena_; + // The output. + DependencyGraph graph_; + // Update the message stored at the node. + void Depend(DependencyGraph::Node* parent, const Expr& child) { + VisitExpr(child); + + CHECK_NE(graph_.expr_node.count(child), 0); + + Depend(parent, graph_.expr_node[child]); + } + + void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) { + auto* parent_link = arena_->make >(); + parent_link->value = parent; + child->parents.Push(parent_link); + + auto* child_link = arena_->make >(); + child_link->value = child; + parent->children.Push(child_link); + } + + std::unordered_set visited_; + + DependencyGraph::Node* NewNode(bool new_scope) { + auto* ret = arena_->make(); + ret->new_scope = new_scope; + return ret; + } + + void VisitExpr(const Expr& e) final { + if (visited_.count(e) == 0) { + if (graph_.expr_node.count(e) == 0) { + graph_.expr_node[e] = NewNode(false); + } + visited_.insert(e); + ExprFunctor::VisitExpr(e); + graph_.post_dfs_order.push_back(graph_.expr_node[e]); + } + } + + void VisitExpr_(const CallNode* c) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(c)]; + Depend(n, c->op); + for (const auto& a : c->args) { + Depend(n, a); + } + } + + void VisitExpr_(const TupleNode* t) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(t)]; + for (const auto& a : t->fields) { + Depend(n, a); + } + } + + void VisitExpr_(const TupleGetItemNode* t) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(t)]; + Depend(n, t->tuple); + } + + void VisitExpr_(const RefCreateNode* r) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(r)]; + Depend(n, r->value); + } + + void VisitExpr_(const RefReadNode* r) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(r)]; + Depend(n, r->ref); + } + + void VisitExpr_(const RefWriteNode* r) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(r)]; + Depend(n, r->ref); + Depend(n, r->value); + } + + void VisitExpr_(const IfNode* i) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(i)]; + DependencyGraph::Node* t = NewNode(true); + DependencyGraph::Node* f = NewNode(true); + Depend(n, i->cond); + Depend(n, t); + Depend(n, f); + Depend(t, i->true_branch); + Depend(f, i->false_branch); + graph_.post_dfs_order.push_back(f); + graph_.post_dfs_order.push_back(t); + } + + void VisitExpr_(const FunctionNode* f) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(f)]; + DependencyGraph::Node* b = NewNode(true); + Depend(n, b); + Depend(b, f->body); + graph_.post_dfs_order.push_back(b); + } + + void VisitExpr_(const LetNode* l) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(l)]; + DependencyGraph::Node* b = NewNode(true); + Depend(n, b); + Depend(b, l->value); + Depend(b, l->body); + graph_.post_dfs_order.push_back(b); + } + + void VisitExpr_(const MatchNode* m) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(m)]; + Depend(n, m->data); + std::vector v; + for (const Clause& c : m->clauses) { + DependencyGraph::Node* b = NewNode(true); + Depend(n, b); + Depend(b, c->rhs); + v.push_back(b); + } + for (auto it = v.rbegin(); it != v.rend(); ++it) { + graph_.post_dfs_order.push_back(*it); + } + } + + void VisitExpr_(const VarNode* v) final { } + + void VisitExpr_(const GlobalVarNode* v) final { } + + void VisitExpr_(const ConstantNode* c) final { } + + void VisitExpr_(const OpNode* o) final { } + + void VisitExpr_(const ConstructorNode* c) final { } +}; + +DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) { + return Creator(arena).Create(body); +} + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/dependency_graph.h b/src/relay/pass/dependency_graph.h new file mode 100644 index 000000000000..91cef1ce7cde --- /dev/null +++ b/src/relay/pass/dependency_graph.h @@ -0,0 +1,57 @@ +/*! + * Copyright (c) 2019 by Contributors. + * \file tvm/relay/pass/dependency_graph.h + * \brief + */ +#ifndef TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ +#define TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ + +#include +#include +#include +#include "let_list.h" +#include "../../common/arena.h" + +namespace tvm { +namespace relay { + +using common::LinkNode; +using common::LinkedList; + +/* DependencyGraph track input and output of an Expr. + * Additionally, dummy scope is created to model scope. + * It allow us to traverse the graph in reverse order. + */ +class DependencyGraph { + public: + /*! \brief A node in the graph. */ + struct Node { + // Determine scope boundaries. Used for calculating scopes, not for + // constructing dependency graph. + bool new_scope = false; + // incoming edges + LinkedList children; + // outgoing edges + LinkedList parents; + }; + + /*! \brief Maps a Relay Expr to its node in the dependency graph. */ + std::unordered_map expr_node; + + /*! \brief The dependency graph in post DFS order. */ + std::vector post_dfs_order; + + /*! + * \brief Create a dependency graph. + * \param arena The arena used for data allocation. + * \param body The body of the expression to create a graph. + */ + static DependencyGraph Create(common::Arena* arena, const Expr& body); + + private: + class Creator; +}; + +} // namespace relay +} // namespace tvm +#endif // TVM_RELAY_PASS_DEPENDENCY_GRAPH_H_ diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 5507de471ae5..1f0ed9eff28e 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -29,193 +29,11 @@ #include "let_list.h" #include "../../common/arena.h" #include "pass_util.h" +#include "dependency_graph.h" namespace tvm { namespace relay { -using common::LinkNode; -using common::LinkedList; - -/* DependencyGraph track input and output of an Expr. - * Additionally, dummy scope is created to model scope. - * It allow us to traverse the graph in reverse order. - */ -class DependencyGraph { - public: - /*! \brief A node in the graph. */ - struct Node { - bool new_scope = false; - LinkedList input; - LinkedList output; - }; - - /*! \brief The node map that maps node to graph */ - std::unordered_map expr_node; - - /*! \brief All the nodes in post DFS order */ - std::vector post_dfs_order; - - /*! - * \brief create a dependency graph. - * \param arena The arena used for data allocation. - * \param body The body of the expression to create a graph. - */ - static DependencyGraph Create(common::Arena* arena, const Expr& body); - - private: - class Creator; -}; - -// Creator of DependencyGraph -class DependencyGraph::Creator : private ExprFunctor { - public: - explicit Creator(common::Arena* arena) - : arena_(arena) {} - - DependencyGraph Create(const Expr& body) { - this->VisitExpr(body); - return std::move(graph_); - } - - private: - /*! \brief allocator of all the internal node object */ - common::Arena* arena_; - // The output. - DependencyGraph graph_; - // Update the message stored at the node. - void Depend(DependencyGraph::Node* parent, const Expr& child) { - VisitExpr(child); - - CHECK_NE(graph_.expr_node.count(child), 0); - - Depend(parent, graph_.expr_node[child]); - } - - void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) { - auto* parent_link = arena_->make >(); - parent_link->value = parent; - child->output.Push(parent_link); - - auto* child_link = arena_->make >(); - child_link->value = child; - parent->input.Push(child_link); - } - - std::unordered_set visited_; - - DependencyGraph::Node* NewNode(bool new_scope) { - auto* ret = arena_->make(); - ret->new_scope = new_scope; - return ret; - } - - void VisitExpr(const Expr& e) final { - if (visited_.count(e) == 0) { - if (graph_.expr_node.count(e) == 0) { - graph_.expr_node[e] = NewNode(false); - } - visited_.insert(e); - ExprFunctor::VisitExpr(e); - graph_.post_dfs_order.push_back(graph_.expr_node[e]); - } - } - - void VisitExpr_(const CallNode* c) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(c)]; - Depend(n, c->op); - for (const auto& a : c->args) { - Depend(n, a); - } - } - - void VisitExpr_(const TupleNode* t) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(t)]; - for (const auto& a : t->fields) { - Depend(n, a); - } - } - - void VisitExpr_(const TupleGetItemNode* t) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(t)]; - Depend(n, t->tuple); - } - - void VisitExpr_(const RefCreateNode* r) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(r)]; - Depend(n, r->value); - } - - void VisitExpr_(const RefReadNode* r) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(r)]; - Depend(n, r->ref); - } - - void VisitExpr_(const RefWriteNode* r) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(r)]; - Depend(n, r->ref); - Depend(n, r->value); - } - - void VisitExpr_(const IfNode* i) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(i)]; - DependencyGraph::Node* t = NewNode(true); - DependencyGraph::Node* f = NewNode(true); - Depend(n, i->cond); - Depend(n, t); - Depend(n, f); - Depend(t, i->true_branch); - Depend(f, i->false_branch); - graph_.post_dfs_order.push_back(f); - graph_.post_dfs_order.push_back(t); - } - - void VisitExpr_(const FunctionNode* f) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(f)]; - DependencyGraph::Node* b = NewNode(true); - Depend(n, b); - Depend(b, f->body); - graph_.post_dfs_order.push_back(b); - } - - void VisitExpr_(const LetNode* l) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(l)]; - DependencyGraph::Node* b = NewNode(true); - Depend(n, b); - Depend(b, l->value); - Depend(b, l->body); - graph_.post_dfs_order.push_back(b); - } - - void VisitExpr_(const MatchNode* m) final { - DependencyGraph::Node* n = graph_.expr_node[GetRef(m)]; - Depend(n, m->data); - std::vector v; - for (const Clause& c : m->clauses) { - DependencyGraph::Node* b = NewNode(true); - Depend(n, b); - Depend(b, c->rhs); - v.push_back(b); - } - for (auto it = v.rbegin(); it != v.rend(); ++it) { - graph_.post_dfs_order.push_back(*it); - } - } - - void VisitExpr_(const VarNode* v) final { } - - void VisitExpr_(const GlobalVarNode* v) final { } - - void VisitExpr_(const ConstantNode* c) final { } - - void VisitExpr_(const OpNode* o) final { } - - void VisitExpr_(const ConstructorNode* c) final { } -}; - -DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) { - return Creator(arena).Create(body); -} - Expr ToANormalForm(const Expr& e, const Module& m, std::set* gv); struct ScopeNode; @@ -256,7 +74,7 @@ std::unordered_map CalcScope(const DependencyGrap Scope global_scope = std::make_shared(); for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { DependencyGraph::Node* n = *it; - auto iit = n->output.head; + auto iit = n->parents.head; Scope s; if (iit == nullptr) { s = global_scope; @@ -313,7 +131,7 @@ class Fill : ExprFunctor { Scope GetSubScope(const Expr& e, size_t i) { DependencyGraph::Node* n = dg_.expr_node.at(e); - auto h = n->input.head; + auto h = n->children.head; while (i != 0) { CHECK(h); --i; diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index 4206d68b83bf..f10b258ff3cf 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -50,8 +50,8 @@ def test_env(): text = env.astext() assert "def @myf" in text assert "def @myf" in str(env) - assert "%1 = add(%0, %0) // ty=float32" in text - assert "%1 = add(%0, %0) // ty=float32" in str(env) + assert "add(%0, %0) /* ty=float32 */" in text + assert "add(%0, %0) /* ty=float32 */" in str(env) show(env.astext(annotate=lambda x: str(x.checked_type.dtype))) show(text) @@ -112,7 +112,7 @@ def test_let_if_scope(): f = relay.Function([x, y, cond], result) text = f.astext() - assert text.count("{") == 6 + assert text.count("{") == 4 assert "%cond: bool" in text show(f.astext()) @@ -180,8 +180,19 @@ def test_call_node_order(): "%2 = fn (%x) {\n" " %x\n" "}\n" - "%3 = %2(%1)\n" - "%3") + "%2(%1)") + +def test_let_inlining(): + tup = relay.Tuple([relay.const(0), relay.const(0)]) + x = relay.var("x") + assert relay.Let(x, tup, tup).astext() == SEMVER + \ + ("%0 = (0, 0)\n" + "let %x = %0\n" + "%0") + + assert relay.Let(x, tup, x).astext() == SEMVER + \ + ("let %x = (0, 0)\n" + "%x") if __name__ == "__main__": do_print[0] = True @@ -201,3 +212,4 @@ def test_call_node_order(): test_let_if_scope() test_variable_name() test_call_node_order() + test_let_inlining() diff --git a/tests/python/relay/test_op_level1.py b/tests/python/relay/test_op_level1.py index 94d3b157dd0d..d83f25db1b77 100644 --- a/tests/python/relay/test_op_level1.py +++ b/tests/python/relay/test_op_level1.py @@ -38,7 +38,7 @@ def check_single_op(opfunc, ref): x = relay.var("x", tp) y = opfunc(x) # test printer - assert ("%0 = {}(%x)".format(y.op.name)) in y.astext() + assert ("{}(%x)".format(y.op.name)) in y.astext() # test type inference assert relay.ir_pass.infer_type(y).checked_type == tp @@ -78,7 +78,7 @@ def check_binary_op(opfunc, ref): y = relay.var("y", t2) z = opfunc(x, y) # test printer - assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() + assert ("{}(%x, %y)".format(z.op.name)) in z.astext() assert relay.ir_pass.infer_type(z).checked_type == t1 if ref is not None: diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index 8db90fbf91f0..0e44bf851dc4 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -29,7 +29,7 @@ def check_binary_op(opfunc, ref): y = relay.var("y", t2) z = opfunc(x, y) # test printer - assert ("%0 = {}(%x, %y)".format(z.op.name)) in z.astext() + assert ("{}(%x, %y)".format(z.op.name)) in z.astext() assert relay.ir_pass.infer_type(z).checked_type == t1 if ref is not None: diff --git a/tests/python/relay/test_type_infer.py b/tests/python/relay/test_type_infer.py index 4dfe59b8a6a3..8e047354fafd 100644 --- a/tests/python/relay/test_type_infer.py +++ b/tests/python/relay/test_type_infer.py @@ -44,7 +44,7 @@ def initialize_box_adt(mod): def test_monomorphic_let(): - "Program: let x = 1; x" + "Program: let %x = 1; %x" sb = relay.ScopeBuilder() x = sb.let('x', relay.const(1.0, "float64")) sb.ret(x) @@ -53,7 +53,7 @@ def test_monomorphic_let(): def test_single_op(): - "Program: fn (x : float32) { let t1 = f(x); t1 }" + "Program: fn (%x : float32) { let %t1 = f(%x); %t1 }" x = relay.var('x', shape=[]) func = relay.Function([x], op.log(x)) ttype = relay.TensorType([], dtype='float32') @@ -63,8 +63,9 @@ def test_single_op(): def test_add_broadcast_op(): """ Program: - fn (x: Tensor[(10, 4), f32], y: Tensor[(5, 10, 1), f32]) -> Tensor[(5, 10, 4), f32] { - x + y + fn (%x: Tensor[(10, 4), float32], %y: Tensor[(5, 10, 1), float32]) + -> Tensor[(5, 10, 4), float32] { + %x + %y } """ x = relay.var('x', shape=(10, 4)) @@ -80,10 +81,10 @@ def test_add_broadcast_op(): def test_dual_op(): """Program: - fn (x : Tensor[f32, (10, 10)]) { - let t1 = log(x); - let t2 = add(t1, x); - t1 + fn (%x : Tensor[(10, 10), float32]) { + let %t1 = log(x); + let %t2 = add(%t1, %x); + %t1 } """ tp = relay.TensorType((10, 10), "float32") @@ -99,8 +100,8 @@ def test_dual_op(): def test_decl(): """Program: - def f(x : Tensor[(10, 10), f32]) { - log(x) + def @f(%x : Tensor[(10, 10), float32]) { + log(%x) } """ tp = relay.TensorType((10, 10)) @@ -113,11 +114,11 @@ def f(x : Tensor[(10, 10), f32]) { def test_recursion(): """ Program: - def f(n: i32, data: f32) -> f32 { - if (n == 0) { - data + def @f(%n: int32, %data: float32) -> float32 { + if (%n == 0) { + %data } else { - f(n - 1, log(data)) + @f(%n - 1, log(%data)) } } """ @@ -134,7 +135,7 @@ def f(n: i32, data: f32) -> f32 { sb.ret(f(relay.subtract(n, relay.const(1, ti32)), relay.log(data))) mod = relay.Module() mod[f] = relay.Function([n, data], sb.get()) - assert "%3 = @f(%1, %2)" in mod.astext() + assert "@f(%1, %2) /* ty=float32 */" in mod.astext() assert mod[f].checked_type == relay.FuncType([ti32, tf32], tf32)