[Relay][Pass] Merge two consecutive reshape ops #6052

Merged Jul 16, 2020 (13 commits)

Changes from 6 commits
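For context, a minimal sketch of the rewrite this PR enables, assuming a standard TVM build; the variable names and shapes are made up for illustration. Two back-to-back relay.reshape calls are collapsed into a single reshape by the new SimplifyExpr pass:

```python
import tvm
from tvm import relay

# Two consecutive reshapes of the same tensor (1*16*16*32 == 8192 elements).
x = relay.var("x", shape=(1, 16, 16, 32), dtype="float32")
y = relay.reshape(x, newshape=(1, 256, 32))
z = relay.reshape(y, newshape=(1, 8192))
mod = tvm.IRModule.from_expr(relay.Function([x], z))

# The new pass folds the two reshapes into a single reshape to (1, 8192).
mod = relay.transform.InferType()(mod)
mod = relay.transform.SimplifyExpr()(mod)
print(mod)
```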
13 changes: 9 additions & 4 deletions include/tvm/relay/dataflow_matcher.h
@@ -42,11 +42,16 @@ class DFPatternCallback;
class DFPatternCallbackNode : public Object {
public:
/*! \brief Pattern this callback matches */
DFPattern pattern_;
DFPattern pattern;
/*! \brief Function to call when finding a matched expression */
PackedFunc function_;
PackedFunc function;
/*! \brief Require InferType to be run before the callback */
bool require_type;
Comment on lines +45 to +49

Contributor:

https://tvm.apache.org/docs/contribute/code_guide.html
https://google.github.io/styleguide/cppguide.html#Variable_Names

Why the move away from the Google Style Guide convention? You seem to use the var_name_ convention in simplify_expr.cc.

Member (Author):

Because these variables are public, it's probably better and more consistent to name them without "_" at the end, imo.


void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("require_type", &require_type);
}

static constexpr const char* _type_key = "DFPatternCallbackNode";
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
@@ -58,7 +63,7 @@ class DFPatternCallbackNode : public Object {
*/
class DFPatternCallback : public ObjectRef {
public:
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback);
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type);
TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

7 changes: 7 additions & 0 deletions include/tvm/relay/transform.h
@@ -360,6 +360,13 @@ TVM_DLL Pass Inline();
*/
TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> entry_functions);

/*!
* \brief Simplify the Relay expression.
*
* \return The pass.
*/
TVM_DLL Pass SimplifyExpr();

} // namespace transform

/*!
23 changes: 15 additions & 8 deletions python/tvm/relay/dataflow_pattern/__init__.py
@@ -687,7 +687,15 @@ class DFPatternCallback:
the callback returns.

Users are expect to inherit from this class and provide a "self.pattern" to match

Parameters
----------
require_type: bool
Whether InferType is required to be run before the callback.
"""
def __init__(self, require_type=False):
self.pattern = None
self.require_type = require_type

def rewrite(self, expr: Expr) -> Expr:
"""
@@ -727,8 +735,8 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp

class _DFPatternCallback(Object):
"""C++ implemenation"""
def __init__(self, pattern, callback):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback)
def __init__(self, pattern, callback, require_type):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)


def rewrite(callbacks, expr: Expr) -> Expr:
@@ -747,12 +755,11 @@ def rewrite(callbacks, expr: Expr) -> Expr:
result : tvm.relay.Expr
The Expression with matched subgraphs rewritten by the callbacks.
"""
if isinstance(callbacks, DFPatternCallback):
tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
else:
tmp = []
for callback in callbacks:
tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks
tmp = []
for callback in callbacks:
assert callback.pattern is not None
tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type))

return ffi.rewrite(tmp, expr)

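To illustrate the new require_type flag from the Python side, here is a rough sketch of a callback that needs checked_type information on matched nodes. The pattern mirrors the reshape-merging idea of this PR, but the code below is illustrative, not the pass implementation:

```python
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, wildcard, rewrite

class MergeConsecutiveReshapes(DFPatternCallback):
    def __init__(self):
        # require_type=True asks the rewriter to run InferType before matching,
        # so pre.checked_type is populated when the callback fires.
        super().__init__(require_type=True)
        self.x = wildcard()
        self.pattern = is_op("reshape")(is_op("reshape")(self.x))

    def callback(self, pre, post, node_map):
        x = node_map[self.x][0]
        # Reshape directly to the shape of the outermost reshape.
        newshape = [int(d) for d in pre.checked_type.shape]
        return relay.reshape(x, newshape=newshape)

x = relay.var("x", shape=(1, 4, 4), dtype="float32")
expr = relay.reshape(relay.reshape(x, (1, 16)), (16,))
merged = rewrite(MergeConsecutiveReshapes(), expr)  # a single reshape(x, (16,))
```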
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
@@ -52,7 +52,7 @@
_reg.register_injective_schedule("take")
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_injective_schedule("contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
@@ -908,7 +908,7 @@ def reverse_reshape(data, newshape):
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))
return _make.contrib_reverse_reshape(data, list(newshape))


def gather(data, axis, indices):
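The public helper relay.reverse_reshape is unchanged by this rename; only the internal operator name drops the leading underscore. A quick usage reminder (shapes chosen arbitrarily for illustration):

```python
from tvm import relay

x = relay.var("x", shape=(10, 2, 3), dtype="float32")
# Special values in newshape are inferred right-to-left; this call now lowers to
# the renamed "contrib_reverse_reshape" op instead of "_contrib_reverse_reshape".
y = relay.reverse_reshape(x, newshape=(-1, 0))
```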
13 changes: 13 additions & 0 deletions python/tvm/relay/transform/transform.py
@@ -909,6 +909,7 @@ def DenseToSparse(weight_name, weight_shape):
"""
return _ffi_api.DenseToSparse(weight_name, weight_shape)


def SimplifyFCTranspose(target_weight_name):
"""
Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
@@ -926,3 +927,15 @@ def SimplifyFCTranspose(target_weight_name):
The registered SimplifyFCTranspose pass.
"""
return _ffi_api.SimplifyFCTranspose(target_weight_name)


def SimplifyExpr():
"""
Simplify the Relay expression, including merging consecutive reshapes.

Returns
-------
ret : tvm.transform.Pass
The registered SimplifyExpr pass.
"""
return _ffi_api.SimplifyExpr()
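A sketch of how the new pass can be applied on its own or in a pipeline. All passes other than SimplifyExpr are pre-existing TVM passes; the particular combination shown here is illustrative, loosely mirroring how build_module.cc and the VM compiler now schedule it right after EliminateCommonSubexpr:

```python
import tvm
from tvm import relay

x = relay.var("x", shape=(2, 8), dtype="float32")
f = relay.Function([x], relay.reshape(relay.reshape(x, (4, 4)), (16,)))
mod = tvm.IRModule.from_expr(f)

seq = tvm.transform.Sequential([
    relay.transform.InferType(),
    relay.transform.EliminateCommonSubexpr(),
    relay.transform.SimplifyExpr(),   # merges the consecutive reshapes
    relay.transform.FoldConstant(),
])
mod = seq(mod)
```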
1 change: 1 addition & 0 deletions src/relay/backend/build_module.cc
@@ -276,6 +276,7 @@ class RelayBuildModule : public runtime::ModuleNode {
}
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
4 changes: 4 additions & 0 deletions src/relay/backend/vm/compiler.cc
@@ -945,10 +945,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::InlinePrimitives());

pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
@@ -959,6 +961,8 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::AlterOpLayout());
}

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());

pass_seqs.push_back(transform::FuseOps());
43 changes: 25 additions & 18 deletions src/relay/ir/dataflow_matcher.cc
@@ -436,7 +436,8 @@ bool MatchPattern(DFPattern pattern, Expr expr) {

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);

/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
/*!
* \brief PatternGrouper does pre-rewriting pattern matching and analysis
*
* This class creates a number of groups of matched expressions, ensures they don't overlap, and
* returns them to the caller for post-analysis rewriting.
@@ -446,7 +447,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern)
*/
class PatternGrouper {
public:
/* \brief Internal Group class for storing analysis */
/*! \brief Internal Group class for storing analysis */
struct Group {
Expr root_node;
int gid;
@@ -456,11 +457,11 @@ class PatternGrouper {
Array<Expr> args;
};

/* \brief Return the group assignments of expressions */
/*! \brief Return the group assignments of expressions */
const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>& GetGIDAssignments() {
return gid_assignments_;
}
/* \brief Group expressions that match the pattern */
/*! \brief Group expressions that match the pattern */
const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) {
groups_.clear();
gid_assignments_.clear();
@@ -474,7 +475,7 @@ class PatternGrouper {
}

protected:
/* \brief Iteratively traverse the Expression in pre-order to find subgraphs
/*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
*
* If we traverse the graph in post-order, we can run into situtations where a small subgraph will
* match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in
@@ -501,7 +502,7 @@ class PatternGrouper {
}
}
}
/* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
/*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
* group overlap analysis */
class MatchExtractor : public ExprMutator {
public:
@@ -563,7 +564,7 @@ class PatternGrouper {
const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
};

/* \brief Create a group based on a matched expression */
/*! \brief Create a group based on a matched expression */
void CreateGroup(const Expr& expr) {
int var_number = 0;

@@ -661,7 +662,7 @@ class PatternGrouper {
groups_[group.gid] = std::move(group);
}

/* \brief EmbedConst implements rules for embedding constants into partitioned functions or
/*! \brief EmbedConst implements rules for embedding constants into partitioned functions or
* lifting them into the function arguments.
*
* The rules depend on what pattern the ConstantNode matched.
@@ -703,21 +704,23 @@ class PatternGrouper {

// Rewrite

DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) {
DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) {
ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
n->pattern_ = std::move(pattern);
n->function_ = std::move(function);
n->pattern = std::move(pattern);
n->function = std::move(function);
n->require_type = require_type;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
.set_body_typed([](DFPattern pattern, PackedFunc function) {
return DFPatternCallback(pattern, function);
.set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) {
return DFPatternCallback(pattern, function, require_type);
});

/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
/*!
* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
* function to rewrite those matches
*
* The class uses PatternGrouper to support the dominator pattern.
@@ -736,14 +739,17 @@ class PatternRewriter : protected MixedModeMutator {
last = post;
for (auto callback : callbacks) {
callback_ = callback;
if (callback_->require_type) {
post = InferType(post);
}
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(callback_->pattern_, post);
groups_ = grouper.GroupMatches(callback_->pattern, post);
gid_assignments_ = grouper.GetGIDAssignments();
memo_.clear();
post = this->VisitExpr(post);
count++;
}
} while (last != post || count >= 100);
} while (!StructuralEqual()(last, post) || count >= 100);
if (count >= 100) {
throw("Observed 100 rewrite passes, possible conflicting passes?");
}
@@ -765,7 +771,7 @@ class PatternRewriter : protected MixedModeMutator {
node_map.insert({kv.first, tmp});
}
// run the user callback function
return callback_->function_(pre, post, Map<DFPattern, Array<Expr>>(node_map));
return callback_->function(pre, post, Map<DFPattern, Array<Expr>>(node_map));
}
return post;
}
@@ -781,7 +787,8 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr) {

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns);

/* \brief PatternPartitioner replaces expressions that match a pattern with function call that
/*!
* \brief PatternPartitioner replaces expressions that match a pattern with function call that
* perform the same computation but allow for further analysis and lowering.
*
* The class uses PatternGrouper to support the dominator pattern.
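In Python-flavored pseudocode, the intended behaviour of the rewrite loop above is roughly the following; run_infer_type and apply_callback are placeholders for the C++ internals (InferType and the grouping/VisitExpr machinery), not real APIs:

```python
import tvm

def rewrite_until_stable(callbacks, expr, max_iters=100):
    post = expr
    count = 0
    while True:
        last = post
        for cb in callbacks:
            if cb.require_type:
                post = run_infer_type(post)   # placeholder for InferType(post)
            post = apply_callback(cb, post)   # placeholder for match + rewrite
        count += 1
        # Structural equality replaces the old pointer comparison (last != post),
        # so a callback that rebuilds an identical expression still terminates.
        if tvm.ir.structural_equal(last, post) or count >= max_iters:
            break
    if count >= max_iters:
        raise RuntimeError("Observed 100 rewrite passes, possible conflicting passes?")
    return post
```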
6 changes: 3 additions & 3 deletions src/relay/op/tensor/transform.cc
@@ -2573,13 +2573,13 @@ Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
static const Op& op = Op::Get("_contrib_reverse_reshape");
static const Op& op = Op::Get("contrib_reverse_reshape");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape);
TVM_REGISTER_GLOBAL("relay.op._make.contrib_reverse_reshape").set_body_typed(MakeReverseReshape);

RELAY_REGISTER_OP("_contrib_reverse_reshape")
RELAY_REGISTER_OP("contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
right to left.
