[RELAY][PASS] Bind, FoldConstant (#2100)
tqchen authored Nov 14, 2018
1 parent 1b86373 commit b252160
Showing 22 changed files with 648 additions and 181 deletions.
11 changes: 11 additions & 0 deletions include/tvm/relay/expr_functor.h
@@ -182,6 +182,17 @@ class ExprMutator
std::unordered_map<Expr, Expr, NodeHash, NodeEqual> memo_;
};

/*!
 * \brief Bind function parameters or free variables.
 *
 * Parameter binding can only happen if expr is a Function.
 * binds cannot change the parameters of internal (nested) functions.
 *
 * \param expr The function or expression to be bound.
 * \param binds The map of variables to the expressions that replace them.
 * \return The expression or function after binding.
 */
Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_FUNCTOR_H_
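As a usage sketch of the binding rule above (not part of the commit), via the relay.bind wrapper that this commit adds on the Python side; binding a function parameter substitutes it into the body and drops it from the signature, with numpy used only to build the constant:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
y = relay.var("y", shape=(2,), dtype="float32")
func = relay.Function([x, y], relay.add(x, y))

# Binding y yields fn(x) { add(x, const) }; x stays a parameter.
bound = relay.bind(func, {y: relay.const(np.ones(2, dtype="float32"))})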
10 changes: 10 additions & 0 deletions include/tvm/relay/op_attr_types.h
@@ -38,6 +38,16 @@ enum OpPatternKind {
/*! \brief the operator pattern */
using TOpPattern = int;

/*!
 * \brief Whether the operator is stateful or contains internal state.
 *
 * All the primitive ops we have registered so far are pure.
 * This attribute is reserved for future compatibility.
 * We can always work around a stateful op by adding an additional
 * handle argument and returning it.
 */
using TOpIsStateful = bool;

/*!
* \brief Computation description interface.
*
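A hedged Python sketch of reading this attribute; it assumes Op.get_attr returns None for an attribute that was never registered, which is the case for TOpIsStateful at this point:

from tvm import relay

add_op = relay.op.get("add")
# Nothing registers TOpIsStateful yet, so the lookup comes back empty;
# a constant-folding pass can safely treat a missing value as "pure".
print(add_op.get_attr("TOpIsStateful"))   # None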
16 changes: 16 additions & 0 deletions include/tvm/relay/pass.h
@@ -143,6 +143,22 @@ tvm::Array<TypeVar> FreeTypeVars(const Expr& expr);
*/
Expr DeadCodeElimination(const Expr& e);

/*!
 * \brief Fold constant expressions.
 * \param expr The expression to be optimized.
 * \return The optimized expression.
 */
Expr FoldConstant(const Expr& expr);

/*!
 * \brief Fuse operations in expr into separate functions.
 * \param expr The expression.
 * \param fuse_opt_level Optimization level.
 * \return The optimized expression.
 */
Expr FuseOps(const Expr& expr, int fuse_opt_level);


/*! \brief A hashing structure in the style of std::hash. */
struct StructuralHash {
/*! \brief Hash a Relay type.
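As a usage sketch of these two passes (through the Python wrappers relay.ir_pass.fold_constant and relay.ir_pass.fuse_ops added in this commit; the interleaved infer_type calls mirror how build() sequences the passes below):

import tvm
from tvm import relay

x = relay.var("x", shape=(2, 2), dtype="float32")
two = relay.add(relay.const(1.0), relay.const(1.0))   # foldable subexpression
func = relay.Function([x], relay.multiply(x, two))

func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.fold_constant(func)          # add(1.0, 1.0) -> constant 2.0
func = relay.ir_pass.infer_type(func)
func = relay.ir_pass.fuse_ops(func, opt_level=1)  # group ops into fused functions
print(func)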
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
@@ -54,7 +54,7 @@
# helper functions
var = expr.var
const = expr.const

bind = expr.bind

# pylint: disable=unused-argument
@register_func("relay.debug")
14 changes: 11 additions & 3 deletions python/tvm/relay/backend/graph_runtime_codegen.py
@@ -102,6 +102,7 @@ def __init__(self, mod, target):
        self.target = target
        self.nodes = []
        self.var_map = {}
        self.params = {}
        self.compile_engine = compile_engine.get()
        self.lowered_funcs = set()
        self._name_map = {}
@@ -162,8 +163,12 @@ def visit_tuple_getitem(self, op):
        assert isinstance(vtuple, tuple)
        return vtuple[op.index]

-    def visit_constant(self, _):
-        raise RuntimeError("constant not supported")
    def visit_constant(self, op):
        # Lift the constant into a named graph input ("p0", "p1", ...) and
        # record its value so codegen can return it as an extra parameter.
        index = len(self.params)
        name = "p%d" % index
        self.params[name] = op.data
        node = InputNode(name, {})
        return self.add_node(node, op.checked_type)

    def visit_function(self, _):
        raise RuntimeError("function not supported")
@@ -312,6 +317,9 @@ def codegen(self, func):
        lowered_funcs : List[tvm.LoweredFunc]
            The lowered functions.

        params : Dict[str, tvm.nd.NDArray]
            Additional constant parameters.
        """
        # First we convert all the parameters into input nodes.
        for param in func.params:
@@ -324,7 +332,7 @@
        self.heads = self.visit(func.body)
        graph_json = self._get_json()
        lowered_funcs = list(self.lowered_funcs)
-        return graph_json, lowered_funcs
        return graph_json, lowered_funcs, self.params

    def _get_unique_name(self, name):
        if name not in self._name_map:
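With this change, constants met during graph codegen surface as synthetic inputs named "p0", "p1", ..., returned alongside the graph and lowered functions. A rough sketch of consuming the new three-tuple (GraphRuntimeCodegen is an internal class; func is assumed to be a relay.Function already run through infer_type and fuse_ops, as in build() below):

from tvm.relay.backend import graph_runtime_codegen as _graph_gen

graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target="llvm")
graph_json, lowered_funcs, params = graph_gen.codegen(func)
for name, value in params.items():
    print(name, value.shape)   # extracted constants: "p0", "p1", ...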
56 changes: 45 additions & 11 deletions python/tvm/relay/build_module.py
@@ -6,13 +6,15 @@
from .. import nd as _nd, target as _target, autotvm
from ..contrib import graph_runtime as _graph_rt
from . import ir_pass
from . import expr
from .backend import interpreter as _interpreter
from .backend import graph_runtime_codegen as _graph_gen

# List of optimization passes and the level at which they are switched on
OPT_PASS_LEVEL = {
    "SimplifyInference": 0,
    "OpFusion": 1,
    "FoldConstant": 2,
    "FoldScaleAxis": 3,
}

@@ -95,22 +97,50 @@ def build_config(**kwargs):
    return BuildConfig(**kwargs)


-def optimize(func):
def _bind_params_by_name(func, params):
    """Bind parameters of a function by name."""
    name_dict = {}
    for arg in func.params:
        name = arg.name_hint
        if name in name_dict:
            name_dict[name] = None
        else:
            name_dict[name] = arg
    bind_dict = {}
    for k, v in params.items():
        if k not in name_dict:
            continue
        arg = name_dict[k]
        if arg is None:
            raise ValueError("Multiple args in the function have name %s" % k)
        bind_dict[arg] = expr.const(v)
    return expr.bind(func, bind_dict)


def optimize(func, params=None):
    """Perform target-invariant optimizations.

    Parameters
    ----------
    func : tvm.relay.Function
        The input to optimization.

    params : Optional[Dict[str, tvm.nd.NDArray]]
        Input parameters to the graph that do not change
        during inference time. Used for constant folding.

    Returns
    -------
    opt_func : tvm.relay.Function
        The optimized version of the function.
    """
    cfg = BuildConfig.current

    # bind expressions
    if params:
        func = _bind_params_by_name(func, params)

    if cfg.pass_enabled("SimplifyInference"):
        func = ir_pass.infer_type(func)
        func = ir_pass.simplify_inference(func)

    if cfg.pass_enabled("FoldScaleAxis"):
        func = ir_pass.infer_type(func)
        func = ir_pass.backward_fold_scale_axis(func)
        func = ir_pass.infer_type(func)
        func = ir_pass.forward_fold_scale_axis(func)

    if cfg.pass_enabled("FoldConstant"):
        func = ir_pass.fold_constant(func)

    return func
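Because "FoldConstant" registers at level 2 in OPT_PASS_LEVEL, it can be toggled through build_config. A small sketch, assuming func is a relay.Function and optimize is reached through build_module:

from tvm.relay import build_module

with build_module.build_config(opt_level=1):
    f1 = build_module.optimize(func)   # FoldConstant skipped: level 2 > 1

with build_module.build_config(opt_level=2):
    f2 = build_module.optimize(func)   # FoldConstant runs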


@@ -147,8 +181,7 @@ def build(func,
    params : dict of str to NDArray
        Input parameters to the graph that do not change
-        during inference time. Used for pre-compute
-        folding optimization.
        during inference time. Used for constant folding.

    Returns
    -------
@@ -176,14 +209,14 @@
    cfg = BuildConfig.current

    with tophub_context:
-        func = optimize(func)
        func = optimize(func, params)
        # Fuse ops before running code gen
        func = ir_pass.infer_type(func)
        func = ir_pass.fuse_ops(func, cfg.opt_level)
        # Graph code generation
        func = ir_pass.infer_type(func)
        graph_gen = _graph_gen.GraphRuntimeCodegen(mod=None, target=target)
-        graph_json, lowered_funcs = graph_gen.codegen(func)
        graph_json, lowered_funcs, params = graph_gen.codegen(func)
        mod = _tvm_build_module(lowered_funcs, target=target, target_host=target_host)
    return graph_json, mod, params

@@ -210,21 +243,22 @@ def __init__(self, mod, ctx, target):
        self.target = target

    def _make_executor(self, func):
-        graph_json, mod, params = build(func, target=self.target)
-        gmodule = _graph_rt.create(graph_json, mod, self.ctx)
-        if params:
-            gmodule.set_input(*params)
        def _graph_wrapper(*args):
            graph_json, mod, params = build(func, target=self.target)
            assert params is None
            gmodule = _graph_rt.create(graph_json, mod, self.ctx)
            # Create map of inputs.
            for i, arg in enumerate(args):
                gmodule.set_input(i, arg)
            # Run the module, and fetch the output.
            gmodule.run()
-            return gmodule.get_output(0)
            # Make a copy so multiple invocations won't hurt perf.
            return gmodule.get_output(0).copyto(_nd.cpu(0))

        return _graph_wrapper


def create_executor(kind="debug",
mod=None,
ctx=None,
Expand Down
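End to end, the new params flow looks like this; a hedged sketch in which the bound input folds into a constant that comes back through the returned params dict (names like "p0" are generated by codegen):

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(4,), dtype="float32")
w = relay.var("w", shape=(4,), dtype="float32")
func = relay.Function([x, w], relay.multiply(x, relay.add(w, w)))

# "w" is bound by name, add(w, w) folds at compile time, and the folded
# constant is returned as a new parameter.
params = {"w": tvm.nd.array(np.ones(4, dtype="float32"))}
graph_json, mod, out_params = relay.build(func, target="llvm", params=params)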
22 changes: 22 additions & 0 deletions python/tvm/relay/expr.py
@@ -6,6 +6,7 @@
import numpy as _np
from .base import RelayNode, register_relay_node
from . import _make
from . import _expr
from . import ty as _ty
from .._ffi import base as _base
from .. import nd as _nd
@@ -577,3 +578,24 @@ def const(value, dtype=None):
    if not isinstance(value, _nd.NDArray):
        raise ValueError("value has to be scalar or NDArray")
    return Constant(value)


def bind(expr, binds):
    """Bind free variables in an expression or the arguments of a function.

    We can bind the parameters of expr if it is a function.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    binds : Union[Map[tvm.relay.Var, tvm.relay.Expr], Map[str, tvm.relay.Expr]]
        The specific bindings.

    Returns
    -------
    result : tvm.relay.Expr
        The expression or function after binding.
    """
    return _expr.Bind(expr, binds)
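For a bare (non-function) expression, bind simply substitutes free variables, which pairs naturally with constant folding; a small sketch:

import numpy as np
import tvm
from tvm import relay

x = relay.var("x", shape=(2,), dtype="float32")
body = relay.add(x, x)            # x occurs free in this bare expression
bound = relay.bind(body, {x: relay.const(np.ones(2, dtype="float32"))})
# bound is add(const, const), which relay.ir_pass.fold_constant can now reduce.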
16 changes: 16 additions & 0 deletions python/tvm/relay/ir_pass.py
@@ -259,6 +259,22 @@ def structural_hash(value):
        raise TypeError(msg)


def fold_constant(expr):
    """Fold the constant expressions in expr.

    Parameters
    ----------
    expr : tvm.relay.Expr
        The input expression.

    Returns
    -------
    transformed_expr : tvm.relay.Expr
        The transformed expression.
    """
    return _ir_pass.FoldConstant(expr)


def fuse_ops(expr, opt_level=1):
    """Fuse operators in expr together.
71 changes: 70 additions & 1 deletion src/relay/ir/expr_functor.cc
@@ -6,8 +6,8 @@
* ExprMutator uses memoization and self return in order to amortize
* the cost of using functional updates.
*/

#include <tvm/relay/expr_functor.h>
#include "type_functor.h"

namespace tvm {
namespace relay {
@@ -228,5 +228,74 @@ void ExprVisitor::VisitExpr_(const TupleGetItemNode* op) {

void ExprVisitor::VisitType(const Type& t) { return; }

// Implement Bind.
class ExprBinder : public ExprMutator {
 public:
  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
      : args_map_(args_map) {
  }

  Expr VisitExpr_(const LetNode* op) final {
    CHECK(!args_map_.count(op->var))
        << "Cannot bind an internal variable in let";
    return ExprMutator::VisitExpr_(op);
  }

  Expr VisitExpr_(const FunctionNode* op) final {
    for (Var param : op->params) {
      CHECK(!args_map_.count(param))
          << "Cannot bind an internal function parameter";
    }
    return ExprMutator::VisitExpr_(op);
  }

  Expr VisitExpr_(const VarNode* op) final {
    auto id = GetRef<Var>(op);
    auto it = args_map_.find(id);
    if (it != args_map_.end()) {
      return (*it).second;
    } else {
      return id;
    }
  }

 private:
  const tvm::Map<Var, Expr>& args_map_;
};

Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
  if (const FunctionNode* func = expr.as<FunctionNode>()) {
    Expr new_body = ExprBinder(args_map).Mutate(func->body);
    Array<Var> new_params;
    for (Var param : func->params) {
      if (!args_map.count(param)) {
        new_params.push_back(param);
      }
    }
    if (new_body.same_as(func->body) &&
        new_params.size() == func->params.size()) {
      return expr;
    }
    return FunctionNode::make(new_params,
                              new_body,
                              func->ret_type,
                              func->type_params,
                              func->attrs);
  } else {
    return ExprBinder(args_map).Mutate(expr);
  }
}

TVM_REGISTER_API("relay._expr.Bind")
.set_body([](TVMArgs args, TVMRetValue* ret) {
    NodeRef input = args[0];
    if (input->derived_from<ExprNode>()) {
      *ret = Bind(Downcast<Expr>(input), args[1]);
    } else {
      CHECK(input->derived_from<TypeNode>());
      *ret = Bind(Downcast<Type>(input), args[1]);
    }
  });
} // namespace relay
} // namespace tvm
2 changes: 0 additions & 2 deletions src/relay/ir/op.cc
@@ -11,8 +11,6 @@
#include <memory>
#include <mutex>

#include "./../pass/type_subst.h"

namespace dmlc {
// enable registry
DMLC_REGISTRY_ENABLE(::tvm::relay::OpRegistry);