Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay][Pass] Merge two consecutive reshape ops #6052

Merged
merged 13 commits into from
Jul 16, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions include/tvm/relay/dataflow_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@ class DFPatternCallback;
class DFPatternCallbackNode : public Object {
public:
/*! \brief Pattern this callback matches */
DFPattern pattern_;
DFPattern pattern;
/*! \brief Function to call when finding a matched expression */
PackedFunc function_;
PackedFunc function;
/*! \brief Require InferType to be run before the callback */
bool require_type;
Comment on lines +45 to +49
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

https://tvm.apache.org/docs/contribute/code_guide.html
https://google.github.io/styleguide/cppguide.html#Variable_Names

Why the move away from the Google Style Guide convention? You seem to use the var_name_ convention in simplify_expr.cc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because these variables are public, it's probably better and more consistent to name it without "_" at the end imo.


void VisitAttrs(tvm::AttrVisitor* v) {}
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("pattern", &pattern);
v->Visit("require_type", &require_type);
}

static constexpr const char* _type_key = "DFPatternCallbackNode";
TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object);
Expand All @@ -58,7 +63,7 @@ class DFPatternCallbackNode : public Object {
*/
class DFPatternCallback : public ObjectRef {
public:
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback);
TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type);
TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode);
};

Expand Down
19 changes: 15 additions & 4 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,7 +687,15 @@ class DFPatternCallback:
the callback returns.

Users are expect to inherit from this class and provide a "self.pattern" to match

Parameters
----------
require_type: bool
Whether InferType is required to be run before the callback.
"""
def __init__(self, require_type=False):
self.pattern = None
self.require_type = require_type

def rewrite(self, expr: Expr) -> Expr:
"""
Expand Down Expand Up @@ -727,8 +735,8 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp

class _DFPatternCallback(Object):
"""C++ implemenation"""
def __init__(self, pattern, callback):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback)
def __init__(self, pattern, callback, require_type):
self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type)


def rewrite(callbacks, expr: Expr) -> Expr:
Expand All @@ -748,11 +756,14 @@ def rewrite(callbacks, expr: Expr) -> Expr:
The Expression with matched subgraphs rewritten by the callbacks.
"""
if isinstance(callbacks, DFPatternCallback):
tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)]
assert callbacks.pattern is not None
tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback, callbacks.require_type)]
else:
tmp = []
for callback in callbacks:
tmp.append(_DFPatternCallback(callback.pattern, callback.callback))
assert callback.pattern is not None
tmp.append(_DFPatternCallback(callback.pattern, callback.callback,
callback.require_type))
icemelon marked this conversation as resolved.
Show resolved Hide resolved

return ffi.rewrite(tmp, expr)

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
_reg.register_injective_schedule("take")
_reg.register_injective_schedule("transpose")
_reg.register_injective_schedule("stack")
_reg.register_injective_schedule("_contrib_reverse_reshape")
_reg.register_injective_schedule("contrib_reverse_reshape")
_reg.register_injective_schedule("gather")
_reg.register_injective_schedule("gather_nd")
_reg.register_injective_schedule("sequence_mask")
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,7 +908,7 @@ def reverse_reshape(data, newshape):
"""
if isinstance(newshape, int):
newshape = [newshape]
return _make._contrib_reverse_reshape(data, list(newshape))
return _make.contrib_reverse_reshape(data, list(newshape))


def gather(data, axis, indices):
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,7 @@ def DenseToSparse(weight_name, weight_shape):
"""
return _ffi_api.DenseToSparse(weight_name, weight_shape)


def SimplifyFCTranspose(target_weight_name):
"""
Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)```
Expand All @@ -926,3 +927,15 @@ def SimplifyFCTranspose(target_weight_name):
The registered SimplifyFCTranspose pass.
"""
return _ffi_api.SimplifyFCTranspose(target_weight_name)


def SimplifyExpr():
"""
Simplify the Relay expression, including merging consecutive reshapes.

Returns
-------
ret : tvm.transform.Pass
The registered SimplifyExpr pass.
"""
return _ffi_api.SimplifyExpr()
43 changes: 25 additions & 18 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ bool MatchPattern(DFPattern pattern, Expr expr) {

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern);

/* \brief PatternGrouper does pre-rewriting pattern matching and analysis
/*!
* \brief PatternGrouper does pre-rewriting pattern matching and analysis
*
* This class creates a number of groups of matched expressions, ensures they don't overlap, and
* returns them to the caller for post-analysis rewriting.
Expand All @@ -446,7 +447,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern)
*/
class PatternGrouper {
public:
/* \brief Internal Group class for storing analysis */
/*! \brief Internal Group class for storing analysis */
struct Group {
Expr root_node;
int gid;
Expand All @@ -456,11 +457,11 @@ class PatternGrouper {
Array<Expr> args;
};

/* \brief Return the group assignments of expressions */
/*! \brief Return the group assignments of expressions */
const std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual>& GetGIDAssignments() {
return gid_assignments_;
}
/* \brief Group expressions that match the pattern */
/*! \brief Group expressions that match the pattern */
const std::unordered_map<int, Group>& GroupMatches(const DFPattern& pattern, const Expr& pre) {
groups_.clear();
gid_assignments_.clear();
Expand All @@ -474,7 +475,7 @@ class PatternGrouper {
}

protected:
/* \brief Iteratively traverse the Expression in pre-order to find subgraphs
/*! \brief Iteratively traverse the Expression in pre-order to find subgraphs
*
* If we traverse the graph in post-order, we can run into situtations where a small subgraph will
* match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in
Expand All @@ -501,7 +502,7 @@ class PatternGrouper {
}
}
}
/* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
/*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
* group overlap analysis */
class MatchExtractor : public ExprMutator {
public:
Expand Down Expand Up @@ -563,7 +564,7 @@ class PatternGrouper {
const std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs_;
};

/* \brief Create a group based on a matched expression */
/*! \brief Create a group based on a matched expression */
void CreateGroup(const Expr& expr) {
int var_number = 0;

Expand Down Expand Up @@ -661,7 +662,7 @@ class PatternGrouper {
groups_[group.gid] = std::move(group);
}

/* \brief EmbedConst implements rules for embedding constants into partitioned functions or
/*! \brief EmbedConst implements rules for embedding constants into partitioned functions or
* lifting them into the function arguments.
*
* The rules depend on what pattern the ConstantNode matched.
Expand Down Expand Up @@ -703,21 +704,23 @@ class PatternGrouper {

// Rewrite

DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) {
DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) {
ObjectPtr<DFPatternCallbackNode> n = make_object<DFPatternCallbackNode>();
n->pattern_ = std::move(pattern);
n->function_ = std::move(function);
n->pattern = std::move(pattern);
n->function = std::move(function);
n->require_type = require_type;
data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback")
.set_body_typed([](DFPattern pattern, PackedFunc function) {
return DFPatternCallback(pattern, function);
.set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) {
return DFPatternCallback(pattern, function, require_type);
});

/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
/*!
* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback
* function to rewrite those matches
*
* The class uses PatternGrouper to support the dominator pattern.
Expand All @@ -736,14 +739,17 @@ class PatternRewriter : protected MixedModeMutator {
last = post;
for (auto callback : callbacks) {
callback_ = callback;
if (callback_->require_type) {
post = InferType(post);
}
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(callback_->pattern_, post);
groups_ = grouper.GroupMatches(callback_->pattern, post);
gid_assignments_ = grouper.GetGIDAssignments();
memo_.clear();
post = this->VisitExpr(post);
count++;
}
} while (last != post || count >= 100);
} while (!StructuralEqual()(last, post) || count >= 100);
if (count >= 100) {
throw("Observed 100 rewrite passes, possible conflicting passes?");
}
Expand All @@ -765,7 +771,7 @@ class PatternRewriter : protected MixedModeMutator {
node_map.insert({kv.first, tmp});
}
// run the user callback function
return callback_->function_(pre, post, Map<DFPattern, Array<Expr>>(node_map));
return callback_->function(pre, post, Map<DFPattern, Array<Expr>>(node_map));
}
return post;
}
Expand All @@ -781,7 +787,8 @@ Expr RewritePatterns(Array<DFPatternCallback> callbacks, Expr expr) {

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns);

/* \brief PatternPartitioner replaces expressions that match a pattern with function call that
/*!
* \brief PatternPartitioner replaces expressions that match a pattern with function call that
* perform the same computation but allow for further analysis and lowering.
*
* The class uses PatternGrouper to support the dominator pattern.
Expand Down
6 changes: 3 additions & 3 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2573,13 +2573,13 @@ Expr MakeReverseReshape(Expr data, Array<Integer> newshape) {
auto attrs = make_object<ReshapeAttrs>();
attrs->newshape = std::move(newshape);
attrs->reverse = true;
static const Op& op = Op::Get("_contrib_reverse_reshape");
static const Op& op = Op::Get("contrib_reverse_reshape");
return Call(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape);
TVM_REGISTER_GLOBAL("relay.op._make.contrib_reverse_reshape").set_body_typed(MakeReverseReshape);

RELAY_REGISTER_OP("_contrib_reverse_reshape")
RELAY_REGISTER_OP("contrib_reverse_reshape")
.describe(R"code(Reshapes the input array where the special values are inferred from
right to left.

Expand Down
124 changes: 124 additions & 0 deletions src/relay/transforms/simplify_expr.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file src/relay/transforms/simplify_expr.cc
* \brief A pass for simplifying the Relay expression.
*/

#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/relay/dataflow_matcher.h>
#include <tvm/support/logging.h>
#include "../op/tensor/transform.h"

namespace tvm {
namespace relay {

static Op reshape_op = Op::Get("reshape");
static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape");

/*!
* \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
* and merges into one reshape op.
*/
class SimplifyReshape {
public:
SimplifyReshape() {
x_ = WildcardPattern(make_object<WildcardPatternNode>());
auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op));
pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {});
}

Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, Array<Expr>>& node_map) {
auto x = node_map[x_][0];
bool const_shape = true;
Array<Integer> newshape;
for (auto dim : Downcast<TensorType>(pre->checked_type())->shape) {
if (dim.as<IntImmNode>() == nullptr) {
const_shape = false;
break;
}
newshape.push_back(Downcast<Integer>(dim));
}
if (const_shape) {
return MakeReshape(x, newshape);
}
return post;
}

DFPattern pattern() const { return pattern_; }

private:
/*! \brief Pattern input */
DFPattern x_;
/*! \brief Pattern for consecutive reshape or reverse_reshape ops */
DFPattern pattern_;

};

/*!
* \brief ExprSimplifier simplifies the Relay expression.
*/
class ExprSimplifier {
public:
ExprSimplifier() {
auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) {
Expr pre = args[0];
Expr post = args[1];
Map<DFPattern, Array<Expr>> node_map = args[2];
*rv = simplify_reshape_.callback(pre, post, node_map);
};
callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe have SimplifyReshape directly inherit DFPatternCallback? You could fold this directly into that and keep it out of the main Simplifier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason that I didn't inherit directly from DFPatternCallback is because you need to create the pattern somewhere else as it's required in the DFPatternCallback constructor.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

:/ I think I focused too much on the Python API and left an Ugly C++ API. I'll see if I can clean that up in a follow up PR. Thanks!

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. :)

true));
}

Expr Simplify(const Expr& expr) {
return RewritePatterns(callbacks_, expr);
}

private:
/*! \brief Simplify reshape pattern */
SimplifyReshape simplify_reshape_;
/*! \brief Callbacks for expr simplification */
Array<DFPatternCallback> callbacks_;
};

Expr SimplifyExpr(const Expr& expr, const IRModule& module) {
return ExprSimplifier().Simplify(expr);
}

namespace transform {

Pass SimplifyExpr() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(SimplifyExpr(f, m));
};
return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"});
}

TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr);

} // namespace transform

} // namespace relay
} // namespace tvm
Loading