From 6e63d286b276fb47522d00a4939f2ffbe6d87e2f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=BE=E9=9B=A8=E9=AD=94=E7=90=86=E6=B2=99?= Date: Thu, 24 Jan 2019 12:49:33 -0800 Subject: [PATCH] [Relay] A Normal Form Canonicalization (#2251) --- include/tvm/relay/pass.h | 20 + python/tvm/relay/ir_pass.py | 42 +- src/relay/pass/let_list.h | 6 +- src/relay/pass/to_anf.cc | 410 ++++++++++++++++++ .../relay/test_pass_dead_code_elimination.py | 4 +- tests/python/relay/test_to_anf.py | 106 +++++ 6 files changed, 577 insertions(+), 11 deletions(-) create mode 100644 src/relay/pass/to_anf.cc create mode 100644 tests/python/relay/test_to_anf.py diff --git a/include/tvm/relay/pass.h b/include/tvm/relay/pass.h index 566d69cc6b0b8..8527ab7a2cb5e 100644 --- a/include/tvm/relay/pass.h +++ b/include/tvm/relay/pass.h @@ -296,6 +296,26 @@ struct StructuralHash { size_t operator()(const Expr& expr) const; }; +/*! \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF). + * + * It will turn an expression that is in a graph form (with sharing implicit), + * to an expression with explicit sharing (A-Normal Form). + * + * The scope of the root expression is the global scope. + + * The scope of any non root expression is the least common ancestor of the scopes of all its users. + * + * Values are ordered by post-DFS order in each scope. + * + * \param e the expression to observably share + * + * \param mod The module used for referencing global functions, can be + * None. + * + * \return expression in A-Normal Form + */ +Expr ToANF(const Expr& e, const Module& mod); + } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/ir_pass.py b/python/tvm/relay/ir_pass.py index d5d5e9261fc79..b27f030e459a2 100644 --- a/python/tvm/relay/ir_pass.py +++ b/python/tvm/relay/ir_pass.py @@ -19,6 +19,7 @@ def post_order_visit(expr, fvisit): ---------- expr : tvm.relay.Expr The input expression. + fvisit : function The visitor function to be applied. 
""" @@ -35,7 +36,6 @@ def infer_type(expr, mod=None): mod: Optional[tvm.relay.Module] The global module. - Returns ------- checked_expr : tvm.relay.Expr @@ -112,11 +112,11 @@ def check_kind(t, mod=None): Parameters ---------- - t: tvm.relay.Type + t : tvm.relay.Type The type to check - mod: tvm.relay.Module, optional - The global module + mod : Optional[tvm.relay.Module] + The global module. Returns ------- @@ -480,8 +480,35 @@ def collect_device_annotation_ops(expr): return _ir_pass.CollectDeviceAnnotationOps(expr) +def to_anf(expr, mod=None): + """ + Turn a Graph Normal Form expression into an A Normal Form expression. + + The scope of the root expression is the global scope. + + The scope of any non root expression is the least common ancestor of the scopes of all its users. + + Values are ordered by post-DFS order in each scope. + + Parameters + ---------- + expr : tvm.relay.Expr + The input expression. + + mod : Optional[tvm.relay.Module] + The global module. + + Returns + ------- + expr : tvm.relay.Expr + The output expression. + """ + return _ir_pass.to_anf(expr, mod) + + def gradient(expr, mod=None): - """. + """ + Transform a function to return the original result paired with the gradient of the input. Parameters ---------- @@ -489,11 +516,10 @@ def gradient(expr, mod=None): The input expression, which is a Function or a GlobalVar. mod : Optional[tvm.relay.Module] - The global module. Returns ------- - ret : tvm.relay.Expr - A function that calculate the original result paired with gradient. + expr : tvm.relay.Expr + The output expression. """ return _ir_pass.first_order_gradient(expr, mod) diff --git a/src/relay/pass/let_list.h b/src/relay/pass/let_list.h index 904ceab36c3d4..2fecc8ba3727e 100644 --- a/src/relay/pass/let_list.h +++ b/src/relay/pass/let_list.h @@ -36,6 +36,7 @@ class LetList { * \return a Var that hold the inserted expr. 
*/ Var Push(Var pv, Expr expr) { + CHECK(!used_); lets_.emplace_back(std::make_pair(pv, expr)); return pv; } @@ -71,11 +72,13 @@ class LetList { * * \return the wrapped expr. */ - Expr Get(const Expr& body) const { + Expr Get(const Expr& body) { + CHECK(!used_); Expr ret = body; for (auto rit = lets_.rbegin(); rit != lets_.rend(); ++rit) { ret = LetNode::make(std::get<0>(*rit), std::get<1>(*rit), ret); } + used_ = true; return ret; } @@ -108,6 +111,7 @@ class LetList { private: std::vector > lets_; + bool used_ = false; }; } // namespace relay diff --git a/src/relay/pass/to_anf.cc b/src/relay/pass/to_anf.cc new file mode 100644 index 0000000000000..3880fd16a286f --- /dev/null +++ b/src/relay/pass/to_anf.cc @@ -0,0 +1,410 @@ +/*! + * Copyright (c) 2018 by Contributors + * + * \file to_anf.cc + * + * \brief Turn implicit sharing into observable sharing. + */ +#include +#include +#include "let_list.h" +#include "../../common/arena.h" + +namespace tvm { +namespace relay { + +using common::LinkNode; +using common::LinkedList; + +/* DependencyGraph tracks the inputs and outputs of an Expr. + * Additionally, dummy scope nodes are created to model scopes. + * It allows us to traverse the graph in reverse order. + */ +class DependencyGraph { + public: + /*! \brief A node in the graph. */ + struct Node { + bool new_scope = false; + LinkedList input; + LinkedList output; + }; + + /*! \brief The node map that maps node to graph */ + std::unordered_map expr_node; + + /*! \brief All the nodes in post DFS order */ + std::vector post_dfs_order; + + /*! + * \brief create a dependency graph. + * \param arena The arena used for data allocation. + * \param body The body of the expression to create a graph. 
+ */ + static DependencyGraph Create(common::Arena* arena, const Expr& body); + + private: + class Creator; +}; + +// Creator of DependencyGraph +class DependencyGraph::Creator : private ExprFunctor { + public: + explicit Creator(common::Arena* arena) + : arena_(arena) {} + + DependencyGraph Create(const Expr& body) { + this->VisitExpr(body); + return std::move(graph_); + } + + private: + /*! \brief allocator of all the internal node object */ + common::Arena* arena_; + // The output. + DependencyGraph graph_; + // Update the message stored at the node. + void Depend(DependencyGraph::Node* parent, const Expr& child) { + VisitExpr(child); + + CHECK_NE(graph_.expr_node.count(child), 0); + + Depend(parent, graph_.expr_node[child]); + } + + void Depend(DependencyGraph::Node* parent, DependencyGraph::Node* child) { + auto* parent_link = arena_->make >(); + parent_link->value = parent; + child->output.Push(parent_link); + + auto* child_link = arena_->make >(); + child_link->value = child; + parent->input.Push(child_link); + } + + std::unordered_set visited_; + + DependencyGraph::Node* NewNode(bool new_scope) { + auto* ret = arena_->make(); + ret->new_scope = new_scope; + return ret; + } + + void VisitExpr(const Expr& e) final { + if (visited_.count(e) == 0) { + if (graph_.expr_node.count(e) == 0) { + graph_.expr_node[e] = NewNode(false); + } + visited_.insert(e); + ExprFunctor::VisitExpr(e); + graph_.post_dfs_order.push_back(graph_.expr_node[e]); + } + } + + void VisitExpr_(const CallNode* c) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(c)]; + Depend(n, c->op); + for (const auto& a : c->args) { + Depend(n, a); + } + } + + void VisitExpr_(const TupleNode* t) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(t)]; + for (const auto& a : t->fields) { + Depend(n, a); + } + } + + void VisitExpr_(const TupleGetItemNode* t) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(t)]; + Depend(n, t->tuple); + } + + void VisitExpr_(const IfNode* i) 
final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(i)]; + DependencyGraph::Node* t = NewNode(true); + DependencyGraph::Node* f = NewNode(true); + Depend(n, i->cond); + Depend(n, t); + Depend(n, f); + Depend(t, i->true_branch); + Depend(f, i->false_branch); + graph_.post_dfs_order.push_back(f); + graph_.post_dfs_order.push_back(t); + } + + void VisitExpr_(const FunctionNode* f) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(f)]; + DependencyGraph::Node* b = NewNode(true); + Depend(n, b); + Depend(b, f->body); + graph_.post_dfs_order.push_back(b); + } + + void VisitExpr_(const LetNode* l) final { + DependencyGraph::Node* n = graph_.expr_node[GetRef(l)]; + DependencyGraph::Node* b = NewNode(true); + Depend(n, b); + Depend(b, l->value); + Depend(b, l->body); + graph_.post_dfs_order.push_back(b); + } + + void VisitExpr_(const VarNode* v) final { } + + void VisitExpr_(const GlobalVarNode* v) final { } + + void VisitExpr_(const ConstantNode* c) final { } + + void VisitExpr_(const OpNode* o) final { } +}; + +DependencyGraph DependencyGraph::Create(common::Arena* arena, const Expr& body) { + return Creator(arena).Create(body); +} + +Expr ToANF(const Expr& e, const Module& m, std::set* gv); + +struct ScopeNode; +using Scope = std::shared_ptr; + +/* Invariant: when parent is null level is 0 + * + * Invariant: when parent is not null level is 1 + parent->level + */ +struct ScopeNode { + size_t level; + Scope parent; + std::shared_ptr ll = std::make_shared(); + explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) { } + ScopeNode() : level(0) { } +}; + +Scope ChildScope(const Scope& s) { + return std::make_shared(s); +} + +Scope LCA(Scope lhs, Scope rhs) { + while (lhs != rhs) { + if (lhs->level > rhs->level) { + lhs = lhs->parent; + } else if (lhs->level < rhs->level) { + rhs = rhs->parent; + } else { + lhs = lhs->parent; + rhs = rhs->parent; + } + } + return lhs; +} + +std::unordered_map CalcScope(const DependencyGraph& dg) 
{ + std::unordered_map expr_scope; + Scope global_scope = std::make_shared(); + for (auto it = dg.post_dfs_order.rbegin(); it != dg.post_dfs_order.rend(); ++it) { + DependencyGraph::Node* n = *it; + auto iit = n->output.head; + Scope s; + if (iit == nullptr) { + s = global_scope; + } else { + s = expr_scope.at(iit->value); + iit = iit->next; + for (; iit != nullptr; iit = iit->next) { + s = LCA(s, expr_scope.at(iit->value)); + } + } + expr_scope.insert({n, n->new_scope ? ChildScope(s) : s}); + } + return expr_scope; +} + +bool IsPrimitiveFunction(const Expr& e) { + return e.as() && Downcast(e)->IsPrimitive(); +} + +class Fill : ExprFunctor { + public: + static Expr ToANF(const Expr& e, + const Module& m, + const DependencyGraph& dg, + std::unordered_map* node_scope, + std::set* gv) { + Fill fi(m, dg, node_scope, gv); + return fi.GetScope(e)->ll->Get(fi.VisitExpr(e)); + } + + private: + Module mod_; + const DependencyGraph& dg_; + std::unordered_map* node_scope_; + std::set* visited_; + std::unordered_map memo; + + Fill(Module mod, + const DependencyGraph& dg, + std::unordered_map* node_scope, + std::set* visited) : + mod_(mod), + dg_(dg), + node_scope_(node_scope), + visited_(visited) { } + + Scope GetScope(const Expr& e) { + return node_scope_->at(dg_.expr_node.at(e)); + } + + Scope GetSubScope(const Expr& e, size_t i) { + DependencyGraph::Node* n = dg_.expr_node.at(e); + auto h = n->input.head; + while (i != 0) { + CHECK(h); + --i; + h = h->next; + } + CHECK(h); + return node_scope_->at(h->value); + } + + Expr VisitExpr(const Expr& e, const Var& v) final { + if (memo.count(e) == 0) { + memo.insert({e, ExprFunctor::VisitExpr(e, v)}); + } + return memo.at(e); + } + + Expr VisitExpr(const Expr& e) { + Var v = VarNode::make(std::string("x"), IncompleteTypeNode::make(TypeVarNode::kType)); + return this->VisitExpr(e, v); + } + + Expr Compound(const Expr& orig, const Expr& now, const Var& v) { + return GetScope(orig)->ll->Push(v, now); + } + + Expr VisitExpr_(const 
CallNode* c, const Var& v) final { + Expr e = GetRef(c); + std::vector args; + for (const auto& a : c->args) { + args.push_back(VisitExpr(a)); + } + return Compound(e, CallNode::make(VisitExpr(c->op), args, c->attrs, c->type_args), v); + } + + Expr VisitExpr_(const TupleNode* t, const Var& v) final { + Expr e = GetRef(t); + std::vector fields; + for (const auto& a : t->fields) { + fields.push_back(VisitExpr(a)); + } + return Compound(e, TupleNode::make(fields), v); + } + + Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final { + Expr e = GetRef(t); + return Compound(e, TupleGetItemNode::make(VisitExpr(t->tuple), t->index), v); + } + + Expr VisitExpr_(const IfNode* i, const Var& v) final { + Expr e = GetRef(i); + Expr ret = IfNode::make(VisitExpr(i->cond), + GetSubScope(e, 1)->ll->Get(VisitExpr(i->true_branch)), + GetSubScope(e, 2)->ll->Get(VisitExpr(i->false_branch))); + return Compound(e, ret, v); + } + + Expr VisitExpr_(const FunctionNode* f, const Var& v) final { + Expr e = GetRef(f); + Expr ret; + if (IsPrimitiveFunction(e)) { + ret = e; + } else { + ret = FunctionNode::make(f->params, + GetSubScope(e, 0)->ll->Get(VisitExpr(f->body)), + f->ret_type, + f->type_params, + f->attrs); + } + return Compound(e, ret, v); + } + + Expr VisitExpr_(const LetNode* l, const Var& v) final { + Expr e = GetRef(l); + VisitExpr(l->value, l->var); + Expr ret = GetSubScope(e, 0)->ll->Get(VisitExpr(l->body)); + return Compound(e, ret, v); + } + + Expr VisitExpr_(const ConstantNode* c, const Var& v) final { + Expr e = GetRef(c); + return Compound(e, e, v); + } + + Expr VisitExpr_(const VarNode* vn, const Var& v) final { + return GetRef(vn); + } + + Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final { + GlobalVar gv = GetRef(gvn); + if (visited_->count(gv) == 0) { + visited_->insert(gv); + mod_->Update(gv, Downcast(relay::ToANF(mod_->Lookup(gv), mod_, visited_))); + } + return gv; + } + + Expr VisitExpr_(const OpNode* op, const Var& v) final { + return GetRef(op); 
+ } +}; + +Expr ToANFAux(const Expr& e, const Module& m, std::set* gv) { + /* When you lift a lambda, what is inside is also being lifted. + * + * So we must determine the scope of the lambda before determining the scope of its body. + * + * To make this more principled, + * we always determine the scope of parent before determining the scope of children. + * + * So we calculate all the dependencies between nodes. + */ + common::Arena arena; + DependencyGraph dg = DependencyGraph::Create(&arena, e); + /* In order to model new subscopes created by lambda, if else and pattern matching, + * we also assign scopes to edges. + * The scope of an edge is either the parent's scope, or a new subscope of the parent's scope. + * + * So, the scope of the whole expr is global. + * The scope of any subexpr is the lowest common ancestor of all incoming edges. + * + * Every scope additionally contains a LetList which collects all values of that scope. + * We do an additional pass to fill all the LetLists and we are done. 
+ */ + std::unordered_map node_scope = CalcScope(dg); + return Fill::ToANF(e, m, dg, &node_scope, gv); +} + +Expr ToANF(const Expr& e, const Module& m, std::set* gv) { + if (auto f = e.as()) { + return FunctionNode::make(f->params, + ToANFAux(f->body, m, gv), + f->ret_type, + f->type_params, + f->attrs); + } else { + return ToANFAux(e, m, gv); + } +} + +Expr ToANF(const Expr& e, const Module& m) { + std::set gv; + return ToANF(e, m, &gv); +} + +TVM_REGISTER_API("relay._ir_pass.to_anf") +.set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = ToANF(args[0], args[1]); + }); + +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/test_pass_dead_code_elimination.py b/tests/python/relay/test_pass_dead_code_elimination.py index f74aaf74e4748..b88f6500de1a8 100644 --- a/tests/python/relay/test_pass_dead_code_elimination.py +++ b/tests/python/relay/test_pass_dead_code_elimination.py @@ -62,9 +62,9 @@ def test_recursion(): relay.Call(f, [subtract(n, relay.const(1.0)), log(data)])) value = relay.Function([n, data], funcbody, e.float32, []) - orig = relay.Let(f, funcbody, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) + orig = relay.Let(f, value, relay.Call(f, [relay.const(2.0), relay.const(10000.0)])) assert alpha_equal(dead_code_elimination(orig), orig) - assert alpha_equal(dead_code_elimination(relay.Let(f, funcbody, e.three)), e.three) + assert alpha_equal(dead_code_elimination(relay.Let(f, value, e.three)), e.three) def test_op_let(): diff --git a/tests/python/relay/test_to_anf.py b/tests/python/relay/test_to_anf.py new file mode 100644 index 0000000000000..5da7e38a81f51 --- /dev/null +++ b/tests/python/relay/test_to_anf.py @@ -0,0 +1,106 @@ +import numpy as np +import tvm +from tvm import relay +from tvm.relay.ir_pass import to_anf, alpha_equal, infer_type +from tvm.relay import op, create_executor +from tvm.relay.backend.interpreter import Value, TupleValue + + +def check_eval(expr, expected_result, mod=None, rtol=1e-07): + ctx = 
tvm.context("llvm", 0) + intrp = create_executor(mod=mod, ctx=ctx, target="llvm") + + result = intrp.evaluate(expr) + np.testing.assert_allclose(result.asnumpy(), expected_result, rtol=rtol) + + +def test_explicit_bound(): + x = relay.const(1) + y = op.add(x, x) + z = op.add(y, y) + f = relay.Function([], op.add(z, z)) + assert not "let" in f.astext() # assert the values are implicitly bounded + anf = to_anf(f) + assert "let" in anf.astext() # assert the values are explicitly bounded + check_eval(f(), 8.0) + check_eval(anf(), 8.0) + + +# test that the construction order does not matter, +# and is instead ordered by the scope and by post-dfs ordering. +def test_order(): + z = relay.const(3) + y = relay.const(2) + x = relay.const(1) + val = x + y * z + check_eval(val, 7.0) + anf = infer_type(to_anf(val)) + a = relay.Var('a', relay.IncompleteType()) + b = relay.Var('b', relay.IncompleteType()) + c = relay.Var('c', relay.IncompleteType()) + d = relay.Var('d', relay.IncompleteType()) + e = relay.Var('e', relay.IncompleteType()) + expected_output = e + expected_output = relay.Let(e, a + d, expected_output) + expected_output = relay.Let(d, b * c, expected_output) + expected_output = relay.Let(c, z, expected_output) + expected_output = relay.Let(b, y, expected_output) + expected_output = relay.Let(a, x, expected_output) + expected_output = infer_type(expected_output) + assert alpha_equal(anf, expected_output) + + +def test_if(): + cond = relay.const(True) + x = relay.If(cond, relay.const(2), relay.const(3)) + anf = infer_type(to_anf(x)) + a = relay.Var('a', relay.IncompleteType()) + b = relay.Var('b', relay.IncompleteType()) + c = relay.Var('c', relay.IncompleteType()) + d = relay.Var('d', relay.IncompleteType()) + true_branch = relay.Let(a, relay.const(2), a) + false_branch = relay.Let(b, relay.const(3), b) + expected_output = relay.If(c, true_branch, false_branch) + expected_output = relay.Let(d, expected_output, d) + expected_output = relay.Let(c, cond, expected_output) 
+ expected_output = infer_type(expected_output) + assert alpha_equal(anf, expected_output) + + +# make sure we don't loop infinitely. +# it is too large so we won't check for the exact program. +def test_recursion(): + """ + Program: + let sum_twice(n: i64) -> i64 = { + m = (n * 2) + if (n == 0) { + return m; + } else { + return m + sum_twice(n - 1); + } + } + sum_twice(5); + """ + return # cannot be run as fuse_ops needs to recursively visit + mod = relay.Module() + i64 = relay.TensorType((), 'int64') + f = relay.GlobalVar("f") + n = relay.Var("n", i64) + m = n * relay.const(2, 'int64') + funcbody = relay.If(relay.equal(n, relay.const(0, 'int64')), + m, + m + f(n - relay.const(1, 'int64'))) + value = relay.Function([n], funcbody, i64, []) + mod[f] = value + check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) + old_f = mod[f] + f = to_anf(f, mod=mod) + check_eval(f(relay.const(5, 'int64')), 30.0, mod=mod) + + +if __name__ == '__main__': + test_explicit_bound() + test_order() + test_if() + test_recursion()