Skip to content

Commit

Permalink
add smart inlining to reduce code bloat
Browse files Browse the repository at this point in the history
  • Loading branch information
joshpoll committed Mar 24, 2019
1 parent 5427163 commit 6f06493
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 29 deletions.
69 changes: 45 additions & 24 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"

namespace tvm {
Expand Down Expand Up @@ -119,7 +120,6 @@ class PrettyPrinter :
explicit PrettyPrinter(bool GNF,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) :
GNF_(GNF),
show_meta_data_(show_meta_data),
annotate_(annotate) {}

Expand Down Expand Up @@ -157,13 +157,18 @@ class PrettyPrinter :
// print in a new scope
doc_stack_.push_back(Doc());
// must print first so doc_stack_.back() reference doesn't become stale
Doc doc = Print(node);
Doc doc = Print(node, false, true);
doc = doc_stack_.back() << doc;
doc_stack_.pop_back();
return doc;
}

Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}

Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
Expand All @@ -181,9 +186,9 @@ class PrettyPrinter :

Doc PrintAttrs(const Attrs& attrs, const Expr& op);

Doc Print(const NodeRef& node, bool meta = false) {
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta);
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
Expand Down Expand Up @@ -249,25 +254,38 @@ class PrettyPrinter :
return val;
}

inline bool IsAtomicExpr(const Expr& expr) {
bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
}

bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>();
}

//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Doc PrintExpr(const Expr& expr, bool meta) {
Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
// Exploit memoization to print GNF.
// The first time we visit an expression, we need to allocate a temp var
// for it. Every subsequent time we can just use its assigned variable.
// This works since hashing uses pointer equality.

// determine whether to inline
bool inline_expr = AlwaysInline(expr);
if (try_inline) {
inline_expr |= IsUnique(expr);
}

auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;

Doc printed_expr;
if (meta) {
printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
} else if (GNF_ && expr.as<LetNode>()) {
} else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
Expand All @@ -276,8 +294,18 @@ class PrettyPrinter :
} else {
printed_expr = VisitExpr(expr);
}
// we choose to inline atomic exprs
if (GNF_ && !IsAtomicExpr(expr)) {

// add expr to doc
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
// Memoization is done in AllocVar.
return memo_[expr];
} else if (inline_expr) {
memo_[expr] = printed_expr;
return printed_expr;
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr;
Expand All @@ -286,18 +314,6 @@ class PrettyPrinter :
}
doc_stack_.back() << "\n";
return temp_var;
} else if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
// Memoization is done in AllocVar.
return memo_[expr];
} else {
memo_[expr] = printed_expr;
if (GNF_ && expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}
return printed_expr;
}
}

Expand Down Expand Up @@ -361,8 +377,9 @@ class PrettyPrinter :

Doc VisitExpr_(const LetNode* op) final {
Doc doc;
doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n";
doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
return doc;
}
Expand Down Expand Up @@ -391,6 +408,8 @@ class PrettyPrinter :
Doc doc;
int counter = 0;
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);

std::ostringstream os;
if (counter++ != 0) {
doc << "\n";
Expand Down Expand Up @@ -595,8 +614,6 @@ class PrettyPrinter :
}

private:
/*! \brief Whether to use GNF. */
bool GNF_;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */
Expand All @@ -613,6 +630,10 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
DependencyGraph dg_;
class AttrPrinter;
friend class AttrPrinter;
};
Expand Down
22 changes: 17 additions & 5 deletions tests/python/relay/test_ir_text_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def test_env():
text = env.astext()
assert "def @myf" in text
assert "def @myf" in str(env)
assert "%1 = add(%0, %0) // ty=float32" in text
assert "%1 = add(%0, %0) // ty=float32" in str(env)
assert "add(%0, %0)" in text
assert "add(%0, %0)" in str(env)
show(env.astext(annotate=lambda x: str(x.checked_type.dtype)))
show(text)

Expand Down Expand Up @@ -96,7 +96,7 @@ def test_let_if_scope():

f = relay.Function([x, y, cond], result)
text = f.astext()
assert text.count("{") == 6
assert text.count("{") == 4
assert "%cond: bool" in text
show(f.astext())

Expand Down Expand Up @@ -164,8 +164,19 @@ def test_call_node_order():
"%2 = fn (%x) {\n"
" %x\n"
"}\n"
"%3 = %2(%1)\n"
"%3")
"%2(%1)")

def test_let_inlining():
tup = relay.Tuple([relay.const(0), relay.const(0)])
x = relay.var("x")
assert relay.Let(x, tup, tup).astext() == SEMVER + \
("%0 = (0, 0)\n"
"let %x = %0\n"
"%0")

assert relay.Let(x, tup, x).astext() == SEMVER + \
("let %x = (0, 0)\n"
"%x")

if __name__ == "__main__":
do_print[0] = True
Expand All @@ -185,3 +196,4 @@ def test_call_node_order():
test_let_if_scope()
test_variable_name()
test_call_node_order()
test_let_inlining()

0 comments on commit 6f06493

Please sign in to comment.