diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h
index 1fa48d372f9e8..06c89f58a09aa 100644
--- a/include/tvm/relay/pass.h
+++ b/include/tvm/relay/pass.h
@@ -227,6 +227,26 @@ struct StructuralHash {
   size_t operator()(const Expr& expr) const;
 };
+/*! \brief Turn a dataflow graph into A Normal Form.
+ *
+ * It turns an expression in graph form (with implicit sharing)
+ * into an expression with explicit sharing (A Normal Form).
+ *
+ * The scope of the root expression is the global scope.
+ *
+ * The scope of any non-root expression is the least common ancestor of the scopes of all its uses.
+ *
+ * Values are ordered by post-DFS order in each scope.
+ *
+ * \param e the expression to observably share
+ *
+ * \param mod The module used for referencing global functions, can be
+ * None.
+ *
+ * \return expression in A Normal Form
+ */
+Expr ToANF(const Expr& e, const Module& mod);
+
 } // namespace relay
 } // namespace tvm
diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py
index 612d140bcc030..7d36c8713dc50 100644
--- a/python/tvm/relay/ir_pass.py
+++ b/python/tvm/relay/ir_pass.py
@@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit):
     ----------
     expr : tvm.relay.Expr
         The input expression.
+
     fvisit : function
         The visitor function to be applied.
     """
@@ -35,7 +36,6 @@ def infer_type(expr, mod=None):
     mod: Optional[tvm.relay.Module]
         The global module.
 
-
     Returns
     -------
     checked_expr : tvm.relay.Expr
@@ -112,11 +112,11 @@ def check_kind(t, mod=None):
     Parameters
     ----------
-    t: tvm.relay.Type
+    t : tvm.relay.Type
         The type to check
-    mod: tvm.relay.Module, optional
-        The global module
+    mod : Optional[tvm.relay.Module]
+        The global module.
 
     Returns
     -------
@@ -414,3 +414,29 @@ def collect_device_annotation_ops(expr):
         annotation expressions.
     """
     return _ir_pass.CollectDeviceAnnotationOps(expr)
+
+
+def to_anf(expr, mod=None):
+    """
+    Turn a Graph Normal Form expression into an A Normal Form expression.
+
+    The scope of the root expression is the global scope.
+
+    The scope of any non-root expression is the least common ancestor of the scopes of all its uses.
+
+    Values are ordered by post-DFS order in each scope.
+
+    Parameters
+    ----------
+    expr : tvm.relay.Expr
+        The input expression.
+
+    mod : Optional[tvm.relay.Module]
+        The global module.
+
+    Returns
+    -------
+    expr : tvm.relay.Expr
+        The output expression.
+    """
+    return _ir_pass.to_anf(expr, mod)
diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h
index 904ceab36c3d4..2fecc8ba3727e 100644
--- a/src/relay/pass/let_list.h
+++ b/src/relay/pass/let_list.h
@@ -36,6 +36,7 @@ class LetList {
    * \return a Var that hold the inserted expr.
    */
   Var Push(Var pv, Expr expr) {
+    CHECK(!used_);
     lets_.emplace_back(std::make_pair(pv, expr));
     return pv;
   }
@@ -71,11 +72,13 @@ class LetList {
    *
    * \return the wrapped expr.
    */
-  Expr Get(const Expr& body) const {
+  Expr Get(const Expr& body) {
+    CHECK(!used_);
     Expr ret = body;
     for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) {
       ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret);
     }
+    used_ = true;
     return ret;
   }
@@ -108,6 +111,7 @@ class LetList {
  private:
   std::vector<std::pair<Var, Expr> > lets_;
+  bool used_ = false;
 };
 } // namespace relay
diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_anf.cc
new file mode 100644
index 0000000000000..e40ae31e3298e
--- /dev/null
+++ b/src/relay/pass/to_anf.cc
@@ -0,0 +1,475 @@
+/*!
+ * Copyright (c) 2018 by Contributors
+ *
+ * \file to_anf.cc
+ *
+ * \brief Turn implicit sharing into observable sharing.
+ */
+#include <tvm/relay/pass.h>
+#include <tvm/relay/expr_functor.h>
+#include "let_list.h"
+#include "../../common/arena.h"
+
+namespace tvm {
+namespace relay {
+
+using common::LinkNode;
+using common::LinkedList;
+
+/* DependencyGraph tracks how an expr is used: the users of each expr are kept in its output list.
+ * It allows us to traverse the graph in reverse order.
+ */
+class DependencyGraph {
+ public:
+  struct Node;
+  /*! \brief A node in the graph. */
+  struct Node {
+    LinkedList<Node*> output;
+  };
+  /*! \brief The map from expressions to their nodes in the graph. */
+  std::unordered_map<Expr, Node*, NodeHash, NodeEqual> node_map;
+  std::unordered_map<Node*, Expr> node_unmap;
+  /*!
+   * \brief Create a dependency graph.
+   * \param arena The arena used for data allocation.
+   * \param body The body of the expression to create a graph.
+   */
+  static DependencyGraph Create(common::Arena* arena, const Expr& body);
+
+ private:
+  class Creator;
+};
+
+// Creator of DependencyGraph
+class DependencyGraph::Creator : private ExprVisitor {
+ public:
+  explicit Creator(common::Arena* arena)
+    : arena_(arena) {}
+
+  DependencyGraph Create(const Expr& body) {
+    this->VisitExpr(body);
+    return std::move(graph_);
+  }
+
+ private:
+  /*! \brief allocator of all the internal node objects */
+  common::Arena* arena_;
+  // The output.
+  DependencyGraph graph_;
+  // Record that child is used by parent, then visit child.
+  void Depend(const Expr& parent, const Expr& child) {
+    if (graph_.node_map.count(parent) == 0) {
+      graph_.node_map[parent] = arena_->make<Node>();
+      graph_.node_unmap[graph_.node_map[parent]] = parent;
+    }
+
+    if (graph_.node_map.count(child) == 0) {
+      graph_.node_map[child] = arena_->make<Node>();
+      graph_.node_unmap[graph_.node_map[child]] = child;
+    }
+
+    auto* parent_link = arena_->make<LinkNode<Node*> >();
+    parent_link->value = graph_.node_map[parent];
+    graph_.node_map[child]->output.Push(parent_link);
+
+    VisitExpr(child);
+  }
+
+  std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
+
+  void VisitExpr(const Expr& e) final {
+    if (visited_.count(e) == 0) {
+      if (graph_.node_map.count(e) == 0) {
+        graph_.node_map[e] = arena_->make<Node>();
+        graph_.node_unmap[graph_.node_map[e]] = e;
+      }
+      visited_.insert(e);
+      ExprFunctor::VisitExpr(e);
+    }
+  }
+
+  void VisitExpr_(const CallNode* c) final {
+    Expr e = GetRef<Expr>(c);
+    Depend(e, c->op);
+    for (const auto& a : c->args) {
+      Depend(e, a);
+    }
+  }
+
+  void VisitExpr_(const TupleNode* t) final {
+    Expr e = GetRef<Expr>(t);
+    for (const auto& a : t->fields) {
+      Depend(e, a);
+    }
+  }
+
+  void VisitExpr_(const TupleGetItemNode* t) final {
+    Expr e = GetRef<Expr>(t);
+    Depend(e, t->tuple);
+  }
+
+  void VisitExpr_(const IfNode* i) final {
+    Expr e = GetRef<Expr>(i);
+    Depend(e, i->cond);
+    Depend(e, i->true_branch);
+    Depend(e, i->false_branch);
+  }
+
+  void VisitExpr_(const FunctionNode* f) final {
+    Expr e = GetRef<Expr>(f);
+    Depend(e, f->body);
+    for (const auto& p : f->params) {
+      Depend(e, p);
+    }
+  }
+
+  void VisitExpr_(const LetNode* l) final {
+    Expr e = GetRef<Expr>(l);
+    Depend(e, l->var);
+    Depend(e, l->value);
+    Depend(e, l->body);
+  }
+
+  void VisitExpr_(const VarNode* v) final { }
+
+  void VisitExpr_(const GlobalVarNode* v) final { }
+
+  void VisitExpr_(const ConstantNode* c) final { }
+
+  void VisitExpr_(const OpNode* o) final { }
+};
+
+DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) {
+  return Creator(arena).Create(body);
+}
+
+Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv);
+
+struct ScopeNode;
+using Scope = std::shared_ptr<ScopeNode>;
+
+/* Invariant: when parent is null, level is 0.
+ *
+ * Invariant: when parent is not null, level is 1 + parent->level.
+ */
+struct ScopeNode {
+  size_t level;
+  Scope parent;
+  std::shared_ptr<LetList> ll = std::make_shared<LetList>();
+  explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { }
+  ScopeNode() : level(0) { }
+};
+
+Scope ChildScope(const Scope& s) {
+  return std::make_shared<ScopeNode>(s);
+}
+
+Scope LCA(Scope lhs, Scope rhs) {
+  while (lhs != rhs) {
+    if (lhs->level > rhs->level) {
+      lhs = lhs->parent;
+    } else if (lhs->level < rhs->level) {
+      rhs = rhs->parent;
+    } else {
+      lhs = lhs->parent;
+      rhs = rhs->parent;
+    }
+  }
+  return lhs;
+}
+
+using Edge = std::pair<Expr, Expr>;
+
+struct EdgeHash {
+  size_t operator()(const Edge& p) const {
+    return dmlc::HashCombine(NodeHash()(p.first), NodeHash()(p.second));
+  }
+};
+
+class ExprScopeMap {
+ public:
+  Scope GetScope(const Expr& e) const {
+    CHECK_NE(expr_scope_.count(e), 0);
+    return expr_scope_.at(e);
+  }
+
+  Scope GetScope(const Edge& edge) const {
+    return edge_scope_.count(edge) != 0 ? edge_scope_.at(edge) : GetScope(edge.first);
+  }
+
+  Scope GetScope(const Expr& e, const LinkedList<DependencyGraph::Node*>& es) const {
+    auto it = es.head;
+    if (it == nullptr) {
+      return global_scope_;
+    } else {
+      Scope s = GetScope(Edge(dg_.node_unmap.at(it->value), e));
+      it = it->next;
+      for (; it != nullptr; it = it->next) {
+        s = LCA(s, GetScope(Edge(dg_.node_unmap.at(it->value), e)));
+      }
+      return s;
+    }
+  }
+
+  // scoped_children should only contain the direct children that have their own scope.
+  void DeclareSubScope(const Expr& e, const std::vector<Expr>& scoped_children) {
+    Scope s = GetScope(e);
+    std::vector<Scope> subscope;
+    for (const Expr& child : scoped_children) {
+      Scope ns = ChildScope(s);
+      subscope.push_back(ns);
+      Edge edge(e, child);
+      CHECK_EQ(edge_scope_.count(edge), 0);
+      edge_scope_.insert({edge, ns});
+    }
+    CHECK_EQ(expr_subscope_.count(e), 0);
+    expr_subscope_.insert({e, subscope});
+  }
+
+  void DeclareScope(const Expr& e, const Scope& s) {
+    CHECK_EQ(expr_scope_.count(e), 0);
+    expr_scope_.insert({e, s});
+  }
+
+  Scope GetSubScope(const Expr& e, size_t i) const {
+    return expr_subscope_.at(e).at(i);
+  }
+
+  explicit ExprScopeMap(const DependencyGraph& dg) : dg_(dg) { }
+
+ private:
+  // Scope of an expression.
+  std::unordered_map<Expr, Scope, NodeHash, NodeEqual> expr_scope_;
+  // Subscopes of an expression.
+  // For example, a conditional creates two subscopes, one for each branch.
+  // The conditional itself uses the original scope.
+  std::unordered_map<Expr, std::vector<Scope>, NodeHash, NodeEqual> expr_subscope_;
+  // Scope of an edge.
+  // Note that an edge is not stored here if its scope is the same as the parent's scope.
+  std::unordered_map<Edge, Scope, EdgeHash> edge_scope_;
+
+  Scope global_scope_ = std::make_shared<ScopeNode>();
+
+  const DependencyGraph& dg_;
+};
+
+class CalcScope : ExprVisitor {
+ public:
+  static ExprScopeMap Calculate(const DependencyGraph& dg, const Expr& e) {
+    CalcScope cs(dg);
+    cs(e);
+    return std::move(cs.esm_);
+  }
+
+ private:
+  struct CalcSubScope : ExprFunctor<std::vector<Expr>(const Expr&)> {
+    std::vector<Expr> VisitExprDefault_(const Node* e) final {
+      return {};
+    }
+
+    std::vector<Expr> VisitExpr_(const FunctionNode* f) final {
+      return {f->body};
+    }
+
+    std::vector<Expr> VisitExpr_(const LetNode* l) final {
+      return {l->body};
+    }
+
+    std::vector<Expr> VisitExpr_(const IfNode* i) final {
+      return {i->true_branch, i->false_branch};
+    }
+  };
+
+  std::unordered_set<Expr, NodeHash, NodeEqual> visited_;
+  const DependencyGraph& dg_;
+  CalcSubScope css;
+  ExprScopeMap esm_;
+  explicit CalcScope(const DependencyGraph& dg) : dg_(dg), esm_(dg) { }
+  void CalculateAux(const Expr& e) {
+    if (visited_.count(e) == 0) {
+      visited_.insert(e);
+      const auto & l = dg_.node_map.at(e)->output;
+      for (auto it = l.head; it != nullptr; it = it->next) {
+        CalculateAux(dg_.node_unmap.at(it->value));
+      }
+      esm_.DeclareScope(e, esm_.GetScope(e, dg_.node_map.at(e)->output));
+      esm_.DeclareSubScope(e, css(e));
+    }
+  }
+
+  void VisitExpr(const Expr& e) final {
+    CalculateAux(e);
+    ExprVisitor::VisitExpr(e);
+  }
+};
+
+bool IsPrimitiveFunction(const Expr& e) {
+  return e.as<FunctionNode>() && Downcast<Function>(e)->IsPrimitive();
+}
+
+class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
+ public:
+  static Expr ToANF(const ExprScopeMap& esm,
+                    const Module& m,
+                    std::set<GlobalVar>* gv,
+                    const Expr& e) {
+    Fill fi(esm, m, gv);
+    return esm.GetScope(e)->ll->Get(fi.VisitExpr(e));
+  }
+
+ private:
+  const ExprScopeMap& esm_;
+  Module mod_;
+  std::set<GlobalVar>* visited_;
+
+  Fill(const ExprScopeMap& esm,
+       Module mod,
+       std::set<GlobalVar>* visited) :
+    esm_(esm),
+    mod_(mod),
+    visited_(visited) { }
+
+  std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo;
+
+  Expr VisitExpr(const Expr& e, const Var& v) final {
+    if (memo.count(e) == 0) {
+      memo.insert({e, ExprFunctor::VisitExpr(e, v)});
+    }
+    return memo.at(e);
+  }
+
+  Expr VisitExpr(const Expr& e) {
+    Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType));
+    return this->VisitExpr(e, v);
+  }
+
+  Expr Compound(const Expr& orig, const Expr& now, const Var& v) {
+    return esm_.GetScope(orig)->ll->Push(v, now);
+  }
+
+  Expr VisitExpr_(const CallNode* c, const Var& v) final {
+    Expr e = GetRef<Expr>(c);
+    std::vector<Expr> args;
+    for (const auto& a : c->args) {
+      args.push_back(VisitExpr(a));
+    }
+    return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v);
+  }
+
+  Expr VisitExpr_(const TupleNode* t, const Var& v) final {
+    Expr e = GetRef<Expr>(t);
+    std::vector<Expr> fields;
+    for (const auto& a : t->fields) {
+      fields.push_back(VisitExpr(a));
+    }
+    return Compound(e, TupleNode::make(fields), v);
+  }
+
+  Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final {
+    Expr e = GetRef<Expr>(t);
+    return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v);
+  }
+
+  Expr VisitExpr_(const IfNode* i, const Var& v) final {
+    Expr e = GetRef<Expr>(i);
+    Expr ret = IfNode::make(VisitExpr(i->cond),
+                            esm_.GetSubScope(e, 0)->ll->Get(VisitExpr(i->true_branch)),
+                            esm_.GetSubScope(e, 1)->ll->Get(VisitExpr(i->false_branch)));
+    return Compound(e, ret, v);
+  }
+
+  Expr VisitExpr_(const FunctionNode* f, const Var& v) final {
+    Expr e = GetRef<Expr>(f);
+    Expr ret;
+    if (IsPrimitiveFunction(e)) {
+      ret = e;
+    } else {
+      ret = FunctionNode::make(f->params,
+                               esm_.GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)),
+                               f->ret_type,
+                               f->type_params,
+                               f->attrs);
+    }
+    return Compound(e, ret, v);
+  }
+
+  Expr VisitExpr_(const LetNode* l, const Var& v) final {
+    Expr e = GetRef<Expr>(l);
+    VisitExpr(l->value, l->var);
+    Expr ret = esm_.GetSubScope(e, 0)->ll->Get(VisitExpr(l->body));
+    return Compound(e, ret, v);
+  }
+
+  Expr VisitExpr_(const ConstantNode* c, const Var& v) final {
+    Expr e = GetRef<Expr>(c);
+    return Compound(e, e, v);
+  }
+
+  Expr VisitExpr_(const VarNode* vn, const Var& v) final {
+    return GetRef<Expr>(vn);
+  }
+
+  Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final {
+    GlobalVar gv = GetRef<GlobalVar>(gvn);
+    if (visited_->count(gv) == 0) {
+      visited_->insert(gv);
+      mod_->Update(gv, Downcast<Function>(relay::ToANF(mod_->Lookup(gv), mod_, visited_)));
+    }
+    return gv;
+  }
+
+  Expr VisitExpr_(const OpNode* op, const Var& v) final {
+    return GetRef<Expr>(op);
+  }
+};
+
+Expr ToANFAux(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
+  /* When we lift a lambda, whatever is inside it is lifted as well.
+   *
+   * So we must determine the scope of the lambda before determining the scope of its body.
+   *
+   * To make this more principled,
+   * we always determine the scope of a parent before determining the scope of its children.
+   *
+   * So we first calculate all the dependencies between nodes.
+   */
+  common::Arena arena;
+  DependencyGraph dg = DependencyGraph::Create(&arena, e);
+  /* In order to model the new subscopes created by lambdas, if-else, and pattern matching,
+   * we assign a scope to every edge as well.
+   * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope.
+   *
+   * So the scope of the whole expr is the global scope.
+   * The scope of any subexpr is the lowest common ancestor of the scopes of all its incoming edges.
+   *
+   * Every scope additionally contains a LetList which collects all values bound in that scope.
+   * We do an additional pass to fill all the LetLists, and we are done.
+   */
+  ExprScopeMap esm = CalcScope::Calculate(dg, e);
+  return Fill::ToANF(esm, m, gv, e);
+}
+
+Expr ToANF(const Expr& e, const Module& m, std::set<GlobalVar>* gv) {
+  if (auto f = e.as<FunctionNode>()) {
+    return FunctionNode::make(f->params,
+                              ToANFAux(f->body, m, gv),
+                              f->ret_type,
+                              f->type_params,
+                              f->attrs);
+  } else {
+    return ToANFAux(e, m, gv);
+  }
+}
+
+Expr ToANF(const Expr& e, const Module& m) {
+  std::set<GlobalVar> gv;
+  return ToANF(e, m, &gv);
+}
+
+TVM_REGISTER_API("relay._ir_pass.to_anf")
+.set_body([](TVMArgs args, TVMRetValue* ret) {
+    *ret = ToANF(args[0], args[1]);
+  });
+
+} // namespace relay
+} // namespace tvm
diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py
index f74aaf74e4748..b88f6500de1a8 100644
--- a/tests/python/relay/test_pass_dead_code_elimination.py
+++ b/tests/python/relay/test_pass_dead_code_elimination.py
@@ -62,9 +62,9 @@ def test_recursion():
                             relay.Call(f, [subtract(n, relay.const(1.0)), log(data)]))
     value = relay.Function([n, data], funcbody, e.float32, [])
-    orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
+    orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)]))
     assert alpha_equal(dead_code_elimination(orig), orig)
-    assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three)
+    assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three)
 
 
 def test_op_let():
diff --git a/tests/python/relay/test_to_anf.py b/tests/python/relay/test_to_anf.py
new file mode 100644
index 0000000000000..5da7e38a81f51
--- /dev/null
+++ b/tests/python/relay/test_to_anf.py
@@ -0,0 +1,106 @@
+import numpy as np
+import tvm
+from tvm import relay
+from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type
+from tvm.relay import op, create_executor
+from tvm.relay.backend.interpreter import Value, TupleValue
+
+
+def check_eval(expr, expected_result, mod=None, rtol=1e-07):
+    ctx = tvm.context("llvm", 0)
+    intrp = create_executor(mod=mod, ctx=ctx, target="llvm")
+
+    result = intrp.evaluate(expr)
+    np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol)
+
+
+def test_explicit_bound():
+    x = relay.const(1)
+    y = op.add(x, x)
+    z = op.add(y, y)
+    f = relay.Function([], op.add(z, z))
+    assert "let" not in f.astext()  # assert the values are implicitly bound
+    anf = to_anf(f)
+    assert "let" in anf.astext()  # assert the values are explicitly bound
+    check_eval(f(), 8.0)
+    check_eval(anf(), 8.0)
+
+
+# test that the construction order does not matter:
+# values are ordered by scope and by post-DFS ordering.
+def test_order():
+    z = relay.const(3)
+    y = relay.const(2)
+    x = relay.const(1)
+    val = x + y * z
+    check_eval(val, 7.0)
+    anf = infer_type(to_anf(val))
+    a = relay.Var('a', relay.IncompleteType())
+    b = relay.Var('b', relay.IncompleteType())
+    c = relay.Var('c', relay.IncompleteType())
+    d = relay.Var('d', relay.IncompleteType())
+    e = relay.Var('e', relay.IncompleteType())
+    expected_output = e
+    expected_output = relay.Let(e, a + d, expected_output)
+    expected_output = relay.Let(d, b * c, expected_output)
+    expected_output = relay.Let(c, z, expected_output)
+    expected_output = relay.Let(b, y, expected_output)
+    expected_output = relay.Let(a, x, expected_output)
+    expected_output = infer_type(expected_output)
+    assert alpha_equal(anf, expected_output)
+
+
+def test_if():
+    cond = relay.const(True)
+    x = relay.If(cond, relay.const(2), relay.const(3))
+    anf = infer_type(to_anf(x))
+    a = relay.Var('a', relay.IncompleteType())
+    b = relay.Var('b', relay.IncompleteType())
+    c = relay.Var('c', relay.IncompleteType())
+    d = relay.Var('d', relay.IncompleteType())
+    true_branch = relay.Let(a, relay.const(2), a)
+    false_branch = relay.Let(b, relay.const(3), b)
+    expected_output = relay.If(c, true_branch, false_branch)
+    expected_output = relay.Let(d, expected_output, d)
+    expected_output = relay.Let(c, cond, expected_output)
+    expected_output = infer_type(expected_output)
+    assert alpha_equal(anf, expected_output)
+
+
+# make sure we don't infinite loop.
+# the output is too large, so we won't check for the exact program.
+def test_recursion():
+    """
+    Program:
+       let sum_twice(n: i64) -> i64 = {
+         m = (n * 2)
+         if (n == 0) {
+           return m;
+         } else {
+           return m + sum_twice(n - 1);
+         }
+       }
+       sum_twice(5);
+    """
+    return  # cannot be run as fuse_ops needs to recursively visit
+    mod = relay.Module()
+    i64 = relay.TensorType((), 'int64')
+    f = relay.GlobalVar("f")
+    n = relay.Var("n", i64)
+    m = n * relay.const(2, 'int64')
+    funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')),
+                        m,
+                        m + f(n - relay.const(1, 'int64')))
+    value = relay.Function([n], funcbody, i64, [])
+    mod[f] = value
+    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+    old_f = mod[f]
+    f = to_anf(f, mod=mod)
+    check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod)
+
+
+if __name__ == '__main__':
+    test_explicit_bound()
+    test_order()
+    test_if()
+    test_recursion()
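
A minimal usage sketch of the new pass, assuming this patch is applied and to_anf is exposed from tvm.relay.ir_pass as above (it mirrors test_explicit_bound; the variable names are illustrative only). In Graph Normal Form the intermediate values are shared implicitly by being referenced twice; after the pass every compound value is bound exactly once by an explicit let:

    import tvm
    from tvm import relay
    from tvm.relay import op
    from tvm.relay.ir_pass import to_anf

    # Graph Normal Form: y and z are shared implicitly by being referenced twice.
    x = relay.const(1)
    y = op.add(x, x)
    z = op.add(y, y)
    f = relay.Function([], op.add(z, z))
    print(f.astext())    # no "let" bindings in the graph-form program

    # A Normal Form: each intermediate value is bound by an explicit let,
    # so the sharing is observable and each value appears once.
    anf = to_anf(f)
    print(anf.astext())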
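A second sketch, under the same assumptions, of the scoping rule stated in the ToANF doc comment: a value used by both branches of an if should be bound at the least common ancestor of the two branch scopes, i.e. before the if, while a value used by only one branch should stay inside that branch. The commented shape is approximate (variable names and the extra bindings the pass introduces for constants are omitted):

    import tvm
    from tvm import relay
    from tvm.relay import op
    from tvm.relay.ir_pass import to_anf

    cond = relay.const(True)
    x = relay.const(1)
    shared = op.add(x, x)                        # used by both branches
    true_only = op.add(shared, relay.const(2))   # used by the true branch only
    body = relay.If(cond, true_only, shared)
    anf = to_anf(relay.Function([], body))
    # Expected shape, roughly:
    #   let shared = add(x, x)      <- hoisted to the LCA scope, before the if
    #   if (cond) {
    #     let t = add(shared, 2)    <- stays in the true branch's scope
    #     t
    #   } else {
    #     shared
    #   }
    print(anf.astext())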