Skip to content

Commit

Permalink
[RELAY] Basic block normal form (#6152)
Browse files Browse the repository at this point in the history
* initial commit

* refactor utils

* add util

* revert anf test

* update test

* fix logging

* fix scope bug

* complete tests

* remove logging

* revert refactoring

* add one more test case

* fix missing var binding

* fix test

* fix lint

* fix lint

* fix clang-format

* fix lint

* fix lint

* commit missing code

* add analysis api

* fix lint

* fix lint

* lint

* add test for func

* address CR

* fix typo

* fix return type

* fix lint

* refactor classes

* fix lint

* remove prints

* address comments

Co-authored-by: Ubuntu <[email protected]>
  • Loading branch information
eric-haibin-lin and Ubuntu authored Aug 4, 2020
1 parent 90bde33 commit b6db7e3
Show file tree
Hide file tree
Showing 13 changed files with 1,098 additions and 147 deletions.
9 changes: 9 additions & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
*/
TVM_DLL bool ConstantCheck(const Expr& e);

/*!
* \brief Check whether an expression is in the basic block normal form.
*
* \param e the expression.
*
* \return whether the expression is in the basic block normal form.
*/
TVM_DLL bool BasicBlockNormalFormCheck(const Expr& e);

/*!
* \brief Check that each Var is only bound once.
*
Expand Down
15 changes: 15 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,21 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
*/
TVM_DLL Pass RewriteAnnotatedOps(int fallback_device);

/*!
* \brief Turn an expression to Basic Block Normal Form.
*
* We define a block as a group of expressions implied by the scope structure.
*
* Each graph node can only belong to a single block.
*
* For any value that is being used in multiple blocks, it has to be referred
* by a Var which is defined in a block, whose scope is the least common ancestor
* of blocks this value is used.
*
* \return The pass.
*/
TVM_DLL Pass ToBasicBlockNormalForm();

/*!
* \brief turn a dataflow graph into Administrative Normal Form, or A-Normal Form (ANF).
*
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,21 @@ def check_constant(expr):
"""
return _ffi_api.check_constant(expr)

def check_basic_block_normal_form(expr):
"""Check whether an expression is in the basic block form
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
result : bool
Whether the expression is in the basic block form.
"""
return _ffi_api.check_basic_block_normal_form(expr)


def free_vars(expr):
"""Get free Vars from expression expr in Post DFS order.
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,21 @@ def ToANormalForm():
"""
return _ffi_api.ToANormalForm()

def ToBasicBlockNormalForm():
"""Turn an expression to Basic Block Normal Form.
We define a block as a group of expressions implied by the scope structure.
Each graph node can only belong to a single block.
For any value that is being used in multiple blocks, it has to be referred
by a Var which is defined in a block, whose scope is the least common ancestor
of blocks this value is used.
Returns
-------
ret: tvm.transform.Pass
The registered pass that transforms an expression into Basic Block Normal Form.
"""
return _ffi_api.ToBasicBlockNormalForm()


def ToCPS(expr, mod=None):
"""
Expand Down
4 changes: 4 additions & 0 deletions src/relay/analysis/dependency_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(f)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
for (const auto& p : f->params) {
Depend(b, p);
}
Depend(b, f->body);
graph_.post_dfs_order.push_back(b);
}
Expand All @@ -145,6 +148,7 @@ class DependencyGraph::Creator : private ExprFunctor<void(const Expr& e)> {
DependencyGraph::Node* n = graph_.expr_node[GetRef<Expr>(l)];
DependencyGraph::Node* b = NewNode(true);
Depend(n, b);
Depend(b, l->var);
Depend(b, l->value);
Depend(b, l->body);
graph_.post_dfs_order.push_back(b);
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class RelayBuildModule : public runtime::ModuleNode {
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
pass_seqs.push_back(transform::ToBasicBlockNormalForm());

// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());
Expand Down
1 change: 1 addition & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,7 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
pass_seqs.push_back(transform::ToBasicBlockNormalForm());
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

Expand Down
6 changes: 6 additions & 0 deletions src/relay/transforms/let_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,12 @@ class LetList {
return ret;
}

/*! \brief get the number of let bindings in the let list.
*
* \return the let list size.
*/
size_t size() const { return lets_.size(); }

/*! \brief generate an LetList and wrap the result automatically.
*
* \param f a function that generate the unwrapped Expr.
Expand Down
88 changes: 88 additions & 0 deletions src/relay/transforms/pass_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@

#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "../analysis/dependency_graph.h"
#include "let_list.h"

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -184,6 +189,89 @@ struct TreeBranchNode : TreeNode<ConditionObjectPtr> {
~TreeBranchNode() {}
};

struct ScopeNode;
using Scope = std::shared_ptr<ScopeNode>;
using NodeScopeMap = std::unordered_map<DependencyGraph::Node*, Scope>;
using ExprSet = std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual>;

/* Invariant: when parent is null level is 0
* Invariant: when parent is not null level is 1 + parent->level
*/
struct ScopeNode {
// the level of the scope
size_t level;
// the parent scope
Scope parent;
// the corresponding let list which holds all let bindings in the scope
std::shared_ptr<LetList> let_list = std::make_shared<LetList>();
explicit ScopeNode(const Scope& parent) : level(1 + parent->level), parent(parent) {}
ScopeNode() : level(0) {}
};

/*! \brief Calculate the scope of nodes in the dependency graph by least common ancestor.
*
* \param dg the input dependency graph
* \param expr_scope the output node -> scope mapping for all nodes.
* \param lifted_exprs the output set of expressions whose scope is lifted due to dependency
*/
std::pair<NodeScopeMap, ExprSet> CalcScope(const DependencyGraph& dg);

/*! \brief find the least common ancestor of lhs scope and rhs scope.
*/
Scope LCA(Scope lhs, Scope rhs);

/* Special care is needed to handle local recursion.
* Fill additionally take a (possibly null) Var argument,
* If it is not null, Fill is required to bind the transformed result to that var.
*/
class Fill : ExprFunctor<Expr(const Expr&, const Var&)> {
public:
static Expr ToANormalForm(const Expr& e, const DependencyGraph& dg, NodeScopeMap* node_scope);

// For basic block normal form, bind expressions only if the original expression's
// scope should be lifted
static Expr ToBasicBlockNormalForm(const Expr& e, const DependencyGraph& dg,
NodeScopeMap* node_scope, ExprSet* lifted);

private:
const DependencyGraph& dg_;
NodeScopeMap* node_scope_ = nullptr;
std::unordered_map<Expr, Expr, ObjectPtrHash, ObjectPtrEqual> memo;
// a set of Expressions to include for let bindings. If set to nullptr
// all Exprs will be pushed to the let list.
ExprSet* include_set_ = nullptr;

Fill(const DependencyGraph& dg, NodeScopeMap* node_scope, ExprSet* include_set)
: dg_(dg), node_scope_(node_scope), include_set_(include_set) {}

Scope GetScope(const Expr& e);
Scope GetSubScope(const Expr& e, size_t i);

Expr VisitExpr(const Expr& e, const Var& v) final;
Expr VisitExpr(const Expr& e);

Expr Atomic(const Expr& e, const Var& v);
// Bind expression `now` to var `v` if the original expression is in the include set, or if
// v is already defined (e.g. coming from a Let expression). Otherwise return `now` directly.
Expr Compound(const Expr& orig, const Expr& now, const Var& v);

Expr VisitExpr_(const CallNode* c, const Var& v) final;
Expr VisitExpr_(const TupleNode* t, const Var& v) final;
Expr VisitExpr_(const TupleGetItemNode* t, const Var& v) final;
Expr VisitExpr_(const RefCreateNode* r, const Var& v) final;
Expr VisitExpr_(const RefReadNode* r, const Var& v) final;
Expr VisitExpr_(const RefWriteNode* r, const Var& v) final;
Expr VisitExpr_(const IfNode* i, const Var& v) final;
Expr VisitExpr_(const FunctionNode* f, const Var& v) final;
Expr VisitExpr_(const LetNode* l, const Var& v) final;
Expr VisitExpr_(const ConstantNode* c, const Var& v) final;
Expr VisitExpr_(const VarNode* vn, const Var& v) final;
Expr VisitExpr_(const GlobalVarNode* gvn, const Var& v) final;
Expr VisitExpr_(const OpNode* op, const Var& v) final;
Expr VisitExpr_(const ConstructorNode* c, const Var& v) final;
Expr VisitExpr_(const MatchNode* m, const Var& v) final;
};

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_TRANSFORMS_PASS_UTIL_H_
Loading

0 comments on commit b6db7e3

Please sign in to comment.