Skip to content

Commit

Permalink
[Relay][Pass] Merge two consecutive reshape ops (apache#6052)
Browse files Browse the repository at this point in the history
  • Loading branch information
icemelon authored and Trevor Morris committed Aug 26, 2020
1 parent c006a9f commit 8f00b1e
Show file tree
Hide file tree
Showing 13 changed files with 320 additions and 42 deletions.
16 changes: 11 additions & 5 deletions include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ class DFPatternCallback;
class DFPatternCallbackNode : public Object {
public:
/*! \brief Pattern this callback matches */
DFPattern pattern_;
DFPattern pattern;
/*! \brief Function to call when finding a matched expression */
PackedFunc function_;
PackedFunc function;
/*! \brief Require InferType to be run before the callback */
bool require_type;

void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("require_type", &require_type);
}

static constexpr const char* _type_key = "DFPatternCallbackNode";
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
Expand All @@ -58,7 +63,7 @@ class DFPatternCallbackNode : public Object {
*/
class DFPatternCallback : public ObjectRef {
public:
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback);
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type);
TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

Expand All @@ -77,11 +82,12 @@ bool MatchPattern(DFPattern pattern, Expr expr);
*
* \param callbacks An array of DFPatternCallback Nodes
* \param expr The expression to rewrite
* \param mod The module that associates with the expr
*
* \return Return An Expr with every match of the pattern inside the callbacks rewritten by the
* functions inside the callbacks
*/
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr);
Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr, IRModule mod = IRModule());

/*!
* \brief Partition all matches of a DFPattern inside an Expr into separate Function calls
Expand Down
7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,13 @@ TVM_DLL Pass Inline();
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);

/*!
* \brief Simplify the Relay expression.
*
* \return The pass.
*/
TVM_DLL Pass SimplifyExpr();

} // namespace transform

/*!
Expand Down
32 changes: 22 additions & 10 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tvm.relay.expr import RelayExpr as Expr

from ... import _ffi as tvm_ffi
from ... import ir as _ir
from ...ir import make_node
from ...ir.base import Node
from ...runtime import Object
Expand Down Expand Up @@ -687,7 +688,15 @@ class DFPatternCallback:
the callback returns.
Users are expect to inherit from this class and provide a "self.pattern" to match
Parameters
----------
require_type: bool
Whether InferType is required to be run before the callback.
"""
def __init__(self, require_type=False):
self.pattern = None
self.require_type = require_type

def rewrite(self, expr: Expr) -> Expr:
"""
Expand Down Expand Up @@ -727,11 +736,11 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp

class _DFPatternCallback(Object):
"""C++ implemenation"""
def __init__(self, pattern, callback):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback)
def __init__(self, pattern, callback, require_type):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)


def rewrite(callbacks, expr: Expr) -> Expr:
def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr:
"""
Rewrite expression with the given callbacks.
Expand All @@ -741,20 +750,23 @@ def rewrite(callbacks, expr: Expr) -> Expr:
The input callback or list of callbacks.
expr : tvm.relay.Expr
The expression to rewrite.
mod : Optional[tvm.ir.IRModule]
The module that associates with the expression.
Returns
-------
result : tvm.relay.Expr
The Expression with matched subgraphs rewritten by the callbacks.
"""
if isinstance(callbacks, DFPatternCallback):
tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
else:
tmp = []
for callback in callbacks:
tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
if mod is None:
mod = _ir.IRModule()
callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks
tmp = []
for callback in callbacks:
assert callback.pattern is not None
tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type))

return ffi.rewrite(tmp, expr)
return ffi.rewrite(tmp, expr, mod)


def partition(pattern: "DFPattern",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
_reg.register_injective_schedule("take")
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_injective_schedule("contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ def reverse_reshape(data, newshape):
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))
return _make.contrib_reverse_reshape(data, list(newshape))


def gather(data, axis, indices):
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ def DenseToSparse(weight_name, weight_shape):
"""
return _ffi_api.DenseToSparse(weight_name, weight_shape)


def SimplifyFCTranspose(target_weight_name):
"""
Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
Expand All @@ -926,3 +927,15 @@ def SimplifyFCTranspose(target_weight_name):
The registered SimplifyFCTranspose pass.
"""
return _ffi_api.SimplifyFCTranspose(target_weight_name)


def SimplifyExpr():
"""
Simplify the Relay expression, including merging consecutive reshapes.
Returns
-------
ret : tvm.transform.Pass
The registered SimplifyExpr pass.
"""
return _ffi_api.SimplifyExpr()
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
Expand Down
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -945,10 +945,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::InlinePrimitives());

pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
Expand All @@ -959,6 +961,8 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::AlterOpLayout());
}

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

pass_seqs.push_back(transform::FuseOps());
Expand Down
Loading

0 comments on commit 8f00b1e

Please sign in to comment.