Skip to content

Commit

Permalink
[Relay][Text Format] Pretty Printer Smart Inlining (apache#2881)
Browse files Browse the repository at this point in the history
  • Loading branch information
joshpoll authored and Wei Chen committed May 13, 2019
1 parent 5b5e0fb commit 7561043
Show file tree
Hide file tree
Showing 9 changed files with 324 additions and 290 deletions.
33 changes: 0 additions & 33 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,39 +925,6 @@ def eliminate_common_subexpr(expr, fskip=None):
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)


def pass_debug_print(ast, show_meta_data=True, annotate=None, gnf=True):
"""
THIS SHOULD BE USED ONLY FOR DEBUGGING, NOT AS AN INTERCHANGE FORMAT!
USE `.astext()` INSTEAD!
A version of the pretty printer intended for debugging passes. Contains
advanced printing options.
Parameters
----------
ast : Union[relay.Expr, relay.Module, relay.Type]
The relay fragment to be turned into text.
show_meta_data : bool
Whether to include meta data section in the text
if there is meta data.
annotate: Optional[relay.Expr->str]
Optional annotate function to provide additional
information in the comment block.
gnf : bool
Whether to print in GNF. If it is disabled, pointers are left implicit.
Returns
-------
text : str
A text representation of `ast`.
"""
return _ir_pass.pass_debug_print(ast, show_meta_data, annotate, gnf)


def partial_evaluate(expr):
"""
Evaluate the static fragment of the code.
Expand Down
112 changes: 63 additions & 49 deletions src/relay/ir/pretty_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,22 @@
* \file pretty_printer.cc
* \brief Pretty printer for Relay programs
* Supports ANF, GNF, and metadata.
*
* Inlining heuristics:
* - Always inline:
* - GlobalVar
* - Constant
* - Op
* - Var
* - Otherwise, inline if the node is at the end of a scope and is used at most once.
*/

#include <tvm/relay/expr_functor.h>
#include <tvm/relay/module.h>
#include <tvm/relay/pattern_functor.h>
#include "doc.h"
#include "type_functor.h"
#include "../pass/dependency_graph.h"
#include "../../lang/attr_functor.h"

namespace tvm {
Expand Down Expand Up @@ -135,10 +145,8 @@ class PrettyPrinter :
public TypeFunctor<Doc(const Type&)>,
public AttrFunctor<Doc(const NodeRef&)> {
public:
explicit PrettyPrinter(bool GNF,
bool show_meta_data,
explicit PrettyPrinter(bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) :
GNF_(GNF),
show_meta_data_(show_meta_data),
annotate_(annotate) {}

Expand All @@ -150,10 +158,9 @@ class PrettyPrinter :
Doc doc;
// additional information in comment.
if (annotate_ != nullptr) {
return doc << " // " << annotate_(expr);
return doc << " /* " << annotate_(expr) << " */";
} else if (expr->checked_type_.defined()) {
doc << " // ty=";
return doc << Print(expr->checked_type());
return doc << " /* ty=" << Print(expr->checked_type()) << " */";
} else {
return doc;
}
Expand All @@ -176,13 +183,18 @@ class PrettyPrinter :
// print in a new scope
doc_stack_.push_back(Doc());
// must print first so doc_stack_.back() reference doesn't become stale
Doc doc = Print(node);
Doc doc = Print(node, false, true);
doc = doc_stack_.back() << doc;
doc_stack_.pop_back();
return doc;
}

Doc PrintFinal(const NodeRef& node) {
if (node.as_derived<ExprNode>()) {
Expr expr = Downcast<Expr>(node);
dg_ = DependencyGraph::Create(&arena_, expr);
}

Doc doc;
doc << PrintScope(node);
if (!meta_.empty()) {
Expand All @@ -200,9 +212,9 @@ class PrettyPrinter :

Doc PrintAttrs(const Attrs& attrs, const Expr& op);

Doc Print(const NodeRef& node, bool meta = false) {
Doc Print(const NodeRef& node, bool meta = false, bool try_inline = false) {
if (node.as_derived<ExprNode>()) {
return PrintExpr(Downcast<Expr>(node), meta);
return PrintExpr(Downcast<Expr>(node), meta, try_inline);
} else if (node.as_derived<TypeNode>()) {
return PrintType(Downcast<Type>(node), meta);
} else if (node.as_derived<ModuleNode>()) {
Expand Down Expand Up @@ -308,25 +320,38 @@ class PrettyPrinter :
return val;
}

inline bool IsAtomicExpr(const Expr& expr) {
bool IsUnique(const Expr& expr) {
return !(dg_.expr_node.at(expr)->parents.head &&
dg_.expr_node.at(expr)->parents.head->next);
}

bool AlwaysInline(const Expr& expr) {
return expr.as<GlobalVarNode>() || expr.as<ConstantNode>() ||
expr.as<OpNode>() || expr.as<VarNode>();
}

//------------------------------------
// Overload of Expr printing functions
//------------------------------------
Doc PrintExpr(const Expr& expr, bool meta) {
Doc PrintExpr(const Expr& expr, bool meta, bool try_inline) {
// Exploit memoization to print GNF.
// The first time we visit an expression, we need to allocate a temp var
// for it. Every subsequent time we can just use its assigned variable.
// This works since hashing uses pointer equality.

// determine whether to inline
bool inline_expr = AlwaysInline(expr);
if (try_inline) {
inline_expr |= IsUnique(expr);
}

auto it = memo_.find(expr);
if (it != memo_.end()) return it->second;

Doc printed_expr;
if (meta) {
printed_expr = meta_.GetMetaNode(GetRef<NodeRef>(expr.get()));
} else if (GNF_ && expr.as<LetNode>()) {
} else if (!inline_expr && expr.as<LetNode>()) {
// wrap GNFed let in brackets
Doc body;
printed_expr << "{";
Expand All @@ -335,28 +360,26 @@ class PrettyPrinter :
} else {
printed_expr = VisitExpr(expr);
}
// we choose to inline atomic exprs
if (GNF_ && !IsAtomicExpr(expr)) {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr;
if (expr.as<CallNode>()) {
doc_stack_.back() << PrintOptionalInfo(expr);
}
doc_stack_.back() << "\n";
return temp_var;
} else if (expr.as<VarNode>()) {

if (expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}

// add expr to doc
if (expr.as<VarNode>()) {
// This is our first time visiting the var and we hit the VarNode case
// in the visitor. Thus the variable is free.
doc_stack_.back() << "free_var " << printed_expr << "\n";
// Memoization is done in AllocVar.
return memo_[expr];
} else {
} else if (inline_expr) {
memo_[expr] = printed_expr;
if (GNF_ && expr.as<CallNode>()) {
printed_expr << PrintOptionalInfo(expr);
}
return printed_expr;
} else {
Doc temp_var = AllocTemp();
memo_[expr] = temp_var;
doc_stack_.back() << temp_var << " = " << printed_expr << "\n";
return temp_var;
}
}

Expand Down Expand Up @@ -420,8 +443,9 @@ class PrettyPrinter :

Doc VisitExpr_(const LetNode* op) final {
Doc doc;
doc << "let " << AllocVar(op->var) << " = " << Print(op->value) << "\n";
doc << "let " << AllocVar(op->var) << " = " << Print(op->value, false, true) << "\n";
// we use a scope here so GNF hoisting doesn't escape too far
// and nested, unique lets are not hoisted
doc << PrintScope(op->body);
return doc;
}
Expand Down Expand Up @@ -456,6 +480,8 @@ class PrettyPrinter :
Doc doc;
int counter = 0;
for (const auto& kv : mod->functions) {
dg_ = DependencyGraph::Create(&arena_, kv.second);

std::ostringstream os;
if (counter++ != 0) {
doc << "\n";
Expand Down Expand Up @@ -664,8 +690,6 @@ class PrettyPrinter :
}

private:
/*! \brief Whether to use GNF. */
bool GNF_;
/*! \brief Whether to print meta data. */
bool show_meta_data_;
/*! \brief additional comment function */
Expand All @@ -682,6 +706,10 @@ class PrettyPrinter :
TextMetaDataContext meta_;
/*! \brief counter of temporary variable */
size_t temp_var_counter_{0};
/*! \brief arena for dependency graph */
common::Arena arena_;
/*! \brief dependency graph of the expr */
DependencyGraph dg_;
class AttrPrinter;
friend class AttrPrinter;
};
Expand Down Expand Up @@ -751,37 +779,23 @@ Doc PrettyPrinter::PrintAttrs(const Attrs& attrs, const Expr& op) {

std::string PrettyPrint_(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
Doc doc;
doc << "v0.0.1" << "\n"
<< PrettyPrinter(gnf, show_meta_data, annotate).PrintFinal(node);
<< PrettyPrinter(show_meta_data, annotate).PrintFinal(node);
return doc.str();
}

std::string AsText(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate, true);
}

std::string PassDebugPrint(const NodeRef& node,
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate,
bool gnf) {
return PrettyPrint_(node, show_meta_data, annotate, gnf);
bool show_meta_data,
runtime::TypedPackedFunc<std::string(Expr)> annotate) {
return PrettyPrint_(node, show_meta_data, annotate);
}

TVM_REGISTER_API("relay._expr.AsText")
.set_body_typed<std::string(const NodeRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>)>(AsText);

TVM_REGISTER_API("relay._ir_pass.pass_debug_print")
.set_body_typed<std::string(const NodeRef&,
bool,
runtime::TypedPackedFunc<std::string(Expr)>,
bool)>(PassDebugPrint);

} // namespace relay
} // namespace tvm
Loading

0 comments on commit 7561043

Please sign in to comment.