From 8bce8dace6d4d2ff6aa32ddb382d454e2bd2a5cb Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Mon, 13 Jul 2020 23:19:35 -0700 Subject: [PATCH 01/13] [Relay][Pass] Merge two consecutive reshape op --- python/tvm/relay/op/_transform.py | 2 +- python/tvm/relay/op/transform.py | 2 +- python/tvm/relay/transform/__init__.py | 1 + python/tvm/relay/transform/simplify_expr.py | 45 ++++++++++++++++++ src/relay/ir/dataflow_matcher.cc | 4 +- src/relay/op/tensor/transform.cc | 6 +-- tests/python/relay/test_pass_simplify_expr.py | 47 +++++++++++++++++++ 7 files changed, 100 insertions(+), 7 deletions(-) create mode 100644 python/tvm/relay/transform/simplify_expr.py create mode 100644 tests/python/relay/test_pass_simplify_expr.py diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index 878b82a19a36..dc1265870475 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -52,7 +52,7 @@ _reg.register_injective_schedule("take") _reg.register_injective_schedule("transpose") _reg.register_injective_schedule("stack") -_reg.register_injective_schedule("_contrib_reverse_reshape") +_reg.register_injective_schedule("contrib_reverse_reshape") _reg.register_injective_schedule("gather") _reg.register_injective_schedule("gather_nd") _reg.register_injective_schedule("sequence_mask") diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index 173db64de258..4f7c83464fc2 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -908,7 +908,7 @@ def reverse_reshape(data, newshape): """ if isinstance(newshape, int): newshape = [newshape] - return _make._contrib_reverse_reshape(data, list(newshape)) + return _make.contrib_reverse_reshape(data, list(newshape)) def gather(data, axis, indices): diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index 138a36611c6f..dd0eefa773eb 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,3 +19,4 @@ # transformation passes from .transform import * from . import memory_alloc +from .simplify_expr import SimplifyExpr diff --git a/python/tvm/relay/transform/simplify_expr.py b/python/tvm/relay/transform/simplify_expr.py new file mode 100644 index 000000000000..337b7b924925 --- /dev/null +++ b/python/tvm/relay/transform/simplify_expr.py @@ -0,0 +1,45 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=unused-argument +""" +A pass for simplifying the Relay expression. +""" +from . import transform +from ..dataflow_pattern import wildcard, is_op, DFPatternCallback, rewrite +from .. import op as _op + +class SimplifyReshapeCallback(DFPatternCallback): + """Callback to merge consecutive reshape ops""" + def __init__(self): + self.x = wildcard() + reshape1 = is_op("reshape") | is_op("contrib_reverse_reshape") + reshape2 = is_op("reshape") | is_op("contrib_reverse_reshape") + self.pattern = reshape1(reshape2(self.x)) + + def callback(self, pre, post, node_map): + x = node_map[self.x][0] + return _op.reshape(x, newshape=pre.checked_type.shape) + + +@transform.function_pass(opt_level=0, required=["InferType"]) +class SimplifyExpr: + """ A pass to simplify the Relay expression.""" + def __init__(self): + self.callbacks = [SimplifyReshapeCallback()] + + def transform_function(self, func, mod, _): + return rewrite(self.callbacks, func) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 57b3013fd04b..17030b09fff0 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -740,10 +740,10 @@ class PatternRewriter : protected MixedModeMutator { groups_ = grouper.GroupMatches(callback_->pattern_, post); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); - post = this->VisitExpr(post); + post = InferType(this->VisitExpr(post)); count++; } - } while (last != post || count >= 100); + } while (!StructuralEqual()(last, post) || count >= 100); if (count >= 100) { throw("Observed 100 rewrite passes, possible conflicting passes?"); } diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b1c2d8b23373..fa16fd10b78e 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2573,13 +2573,13 @@ Expr MakeReverseReshape(Expr data, Array newshape) { auto attrs = make_object(); attrs->newshape = std::move(newshape); attrs->reverse = true; - static const Op& op = Op::Get("_contrib_reverse_reshape"); + static const Op& op = Op::Get("contrib_reverse_reshape"); return Call(op, {data}, Attrs(attrs), {}); } -TVM_REGISTER_GLOBAL("relay.op._make._contrib_reverse_reshape").set_body_typed(MakeReverseReshape); +TVM_REGISTER_GLOBAL("relay.op._make.contrib_reverse_reshape").set_body_typed(MakeReverseReshape); -RELAY_REGISTER_OP("_contrib_reverse_reshape") +RELAY_REGISTER_OP("contrib_reverse_reshape") .describe(R"code(Reshapes the input array where the special values are inferred from right to left. diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py new file mode 100644 index 000000000000..3cc0b379df72 --- /dev/null +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -0,0 +1,47 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import transform +from tvm.relay.testing import run_opt_pass + + +def test_simplify_reshape(): + def before(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(1, 16, -1)) + y = relay.reshape(y, newshape=(4, 8, -1, 16)) + y = relay.reverse_reshape(y, newshape=(32, 0, -1)) + return relay.Function([x, w], y) + + def expected(): + x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(32, 16, 16)) + return relay.Function([x, w], y) + + z = before() + zz = run_opt_pass(z, transform.SimplifyExpr()) + after = run_opt_pass(expected(), transform.InferType()) + assert tvm.ir.structural_equal(zz, after) + + +if __name__ == "__main__": + test_simplify_reshape() From 8b693885a3238c792e6566623cfa75da2791a972 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 14 Jul 2020 14:11:48 -0700 Subject: [PATCH 02/13] comments --- include/tvm/relay/dataflow_matcher.h | 13 +- python/tvm/relay/dataflow_pattern/__init__.py | 19 ++- python/tvm/relay/transform/__init__.py | 1 - python/tvm/relay/transform/simplify_expr.py | 45 ------- python/tvm/relay/transform/transform.py | 13 ++ src/relay/ir/dataflow_matcher.cc | 43 +++--- src/relay/transforms/simplify_expr.cc | 124 ++++++++++++++++++ tests/python/relay/test_dataflow_pattern.py | 12 ++ tests/python/relay/test_pass_simplify_expr.py | 15 ++- 9 files changed, 212 insertions(+), 73 deletions(-) delete mode 100644 python/tvm/relay/transform/simplify_expr.py create mode 100644 src/relay/transforms/simplify_expr.cc diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index bb53ad32d9f4..8b72836b3334 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -42,11 +42,16 @@ class DFPatternCallback; class DFPatternCallbackNode : public Object { public: /*! \brief Pattern this callback matches */ - DFPattern pattern_; + DFPattern pattern; /*! \brief Function to call when finding a matched expression */ - PackedFunc function_; + PackedFunc function; + /*! \brief Require InferType to be run before the callback */ + bool require_type; - void VisitAttrs(tvm::AttrVisitor* v) {} + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("require_type", &require_type); + } static constexpr const char* _type_key = "DFPatternCallbackNode"; TVM_DECLARE_BASE_OBJECT_INFO(DFPatternCallbackNode, Object); @@ -58,7 +63,7 @@ class DFPatternCallbackNode : public Object { */ class DFPatternCallback : public ObjectRef { public: - TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback); + TVM_DLL DFPatternCallback(DFPattern pattern, PackedFunc callback, bool require_type); TVM_DEFINE_OBJECT_REF_METHODS(DFPatternCallback, ObjectRef, DFPatternCallbackNode); }; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 317d28e1dbea..37a924427b39 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -687,7 +687,15 @@ class DFPatternCallback: the callback returns. Users are expect to inherit from this class and provide a "self.pattern" to match + + Parameters + ---------- + require_type: bool + Whether InferType is required to be run before the callback. """ + def __init__(self, require_type=False): + self.pattern = None + self.require_type = require_type def rewrite(self, expr: Expr) -> Expr: """ @@ -727,8 +735,8 @@ def callback(self, pre: Expr, post: Expr, node_map: tvm.ir.container.Map) -> Exp class _DFPatternCallback(Object): """C++ implemenation""" - def __init__(self, pattern, callback): - self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback) + def __init__(self, pattern, callback, require_type): + self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) def rewrite(callbacks, expr: Expr) -> Expr: @@ -748,11 +756,14 @@ def rewrite(callbacks, expr: Expr) -> Expr: The Expression with matched subgraphs rewritten by the callbacks. """ if isinstance(callbacks, DFPatternCallback): - tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback)] + assert callbacks.pattern is not None + tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback, callbacks.require_type)] else: tmp = [] for callback in callbacks: - tmp.append(_DFPatternCallback(callback.pattern, callback.callback)) + assert callback.pattern is not None + tmp.append(_DFPatternCallback(callback.pattern, callback.callback, + callback.require_type)) return ffi.rewrite(tmp, expr) diff --git a/python/tvm/relay/transform/__init__.py b/python/tvm/relay/transform/__init__.py index dd0eefa773eb..138a36611c6f 100644 --- a/python/tvm/relay/transform/__init__.py +++ b/python/tvm/relay/transform/__init__.py @@ -19,4 +19,3 @@ # transformation passes from .transform import * from . import memory_alloc -from .simplify_expr import SimplifyExpr diff --git a/python/tvm/relay/transform/simplify_expr.py b/python/tvm/relay/transform/simplify_expr.py deleted file mode 100644 index 337b7b924925..000000000000 --- a/python/tvm/relay/transform/simplify_expr.py +++ /dev/null @@ -1,45 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# pylint: disable=unused-argument -""" -A pass for simplifying the Relay expression. -""" -from . import transform -from ..dataflow_pattern import wildcard, is_op, DFPatternCallback, rewrite -from .. import op as _op - -class SimplifyReshapeCallback(DFPatternCallback): - """Callback to merge consecutive reshape ops""" - def __init__(self): - self.x = wildcard() - reshape1 = is_op("reshape") | is_op("contrib_reverse_reshape") - reshape2 = is_op("reshape") | is_op("contrib_reverse_reshape") - self.pattern = reshape1(reshape2(self.x)) - - def callback(self, pre, post, node_map): - x = node_map[self.x][0] - return _op.reshape(x, newshape=pre.checked_type.shape) - - -@transform.function_pass(opt_level=0, required=["InferType"]) -class SimplifyExpr: - """ A pass to simplify the Relay expression.""" - def __init__(self): - self.callbacks = [SimplifyReshapeCallback()] - - def transform_function(self, func, mod, _): - return rewrite(self.callbacks, func) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index ede63808d4fd..7db068785ba6 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -909,6 +909,7 @@ def DenseToSparse(weight_name, weight_shape): """ return _ffi_api.DenseToSparse(weight_name, weight_shape) + def SimplifyFCTranspose(target_weight_name): """ Rewrite ```y = nn.dense(x, transpose(w, [1, 0]))``` to ```y = nn.dense(x, wt)``` @@ -926,3 +927,15 @@ def SimplifyFCTranspose(target_weight_name): The registered SimplifyFCTranspose pass. """ return _ffi_api.SimplifyFCTranspose(target_weight_name) + + +def SimplifyExpr(): + """ + Simplify the Relay expression, including merging consecutive reshapes. + + Returns + ------- + ret : tvm.transform.Pass + The registered SimplifyExpr pass. + """ + return _ffi_api.SimplifyExpr() diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 17030b09fff0..d16172e6fbdc 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -436,7 +436,8 @@ bool MatchPattern(DFPattern pattern, Expr expr) { TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern); -/* \brief PatternGrouper does pre-rewriting pattern matching and analysis +/*! + * \brief PatternGrouper does pre-rewriting pattern matching and analysis * * This class creates a number of groups of matched expressions, ensures they don't overlap, and * returns them to the caller for post-analysis rewriting. @@ -446,7 +447,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.match").set_body_typed(MatchPattern) */ class PatternGrouper { public: - /* \brief Internal Group class for storing analysis */ + /*! \brief Internal Group class for storing analysis */ struct Group { Expr root_node; int gid; @@ -456,11 +457,11 @@ class PatternGrouper { Array args; }; - /* \brief Return the group assignments of expressions */ + /*! \brief Return the group assignments of expressions */ const std::unordered_map& GetGIDAssignments() { return gid_assignments_; } - /* \brief Group expressions that match the pattern */ + /*! \brief Group expressions that match the pattern */ const std::unordered_map& GroupMatches(const DFPattern& pattern, const Expr& pre) { groups_.clear(); gid_assignments_.clear(); @@ -474,7 +475,7 @@ class PatternGrouper { } protected: - /* \brief Iteratively traverse the Expression in pre-order to find subgraphs + /*! \brief Iteratively traverse the Expression in pre-order to find subgraphs * * If we traverse the graph in post-order, we can run into situtations where a small subgraph will * match the pattern. Due to options like AltPattern, a larger subgraph with more nodes later in @@ -501,7 +502,7 @@ class PatternGrouper { } } } - /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform + /*! \brief Creates a new set of nodes based on Group inputs, used to create functions and perform * group overlap analysis */ class MatchExtractor : public ExprMutator { public: @@ -563,7 +564,7 @@ class PatternGrouper { const std::unordered_map inputs_; }; - /* \brief Create a group based on a matched expression */ + /*! \brief Create a group based on a matched expression */ void CreateGroup(const Expr& expr) { int var_number = 0; @@ -661,7 +662,7 @@ class PatternGrouper { groups_[group.gid] = std::move(group); } - /* \brief EmbedConst implements rules for embedding constants into partitioned functions or + /*! \brief EmbedConst implements rules for embedding constants into partitioned functions or * lifting them into the function arguments. * * The rules depend on what pattern the ConstantNode matched. @@ -703,21 +704,23 @@ class PatternGrouper { // Rewrite -DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function) { +DFPatternCallback::DFPatternCallback(DFPattern pattern, PackedFunc function, bool require_type) { ObjectPtr n = make_object(); - n->pattern_ = std::move(pattern); - n->function_ = std::move(function); + n->pattern = std::move(pattern); + n->function = std::move(function); + n->require_type = require_type; data_ = std::move(n); } TVM_REGISTER_NODE_TYPE(DFPatternCallbackNode); TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") - .set_body_typed([](DFPattern pattern, PackedFunc function) { - return DFPatternCallback(pattern, function); + .set_body_typed([](DFPattern pattern, PackedFunc function, bool require_type) { + return DFPatternCallback(pattern, function, require_type); }); -/* \brief PatternRewriter rewrites the expression by finding matches and allowing user callback +/*! + * \brief PatternRewriter rewrites the expression by finding matches and allowing user callback * function to rewrite those matches * * The class uses PatternGrouper to support the dominator pattern. @@ -736,11 +739,14 @@ class PatternRewriter : protected MixedModeMutator { last = post; for (auto callback : callbacks) { callback_ = callback; + if (callback_->require_type) { + post = InferType(post); + } auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern_, post); + groups_ = grouper.GroupMatches(callback_->pattern, post); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); - post = InferType(this->VisitExpr(post)); + post = this->VisitExpr(post); count++; } } while (!StructuralEqual()(last, post) || count >= 100); @@ -765,7 +771,7 @@ class PatternRewriter : protected MixedModeMutator { node_map.insert({kv.first, tmp}); } // run the user callback function - return callback_->function_(pre, post, Map>(node_map)); + return callback_->function(pre, post, Map>(node_map)); } return post; } @@ -781,7 +787,8 @@ Expr RewritePatterns(Array callbacks, Expr expr) { TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns); -/* \brief PatternPartitioner replaces expressions that match a pattern with function call that +/*! + * \brief PatternPartitioner replaces expressions that match a pattern with function call that * perform the same computation but allow for further analysis and lowering. * * The class uses PatternGrouper to support the dominator pattern. diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc new file mode 100644 index 000000000000..c3af05ea759f --- /dev/null +++ b/src/relay/transforms/simplify_expr.cc @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/transforms/simplify_expr.cc + * \brief A pass for simplifying the Relay expression. + */ + +#include +#include +#include +#include +#include +#include "../op/tensor/transform.h" + +namespace tvm { +namespace relay { + +static Op reshape_op = Op::Get("reshape"); +static Op reverse_reshape_op = Op::Get("contrib_reverse_reshape"); + +/*! + * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops, + * and merges into one reshape op. + */ +class SimplifyReshape { + public: + SimplifyReshape() { + x_ = WildcardPattern(make_object()); + auto reshape1 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); + auto reshape2 = AltPattern(ExprPattern(reshape_op), ExprPattern(reverse_reshape_op)); + pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, {})}, Attrs{}, {}); + } + + Expr callback(const Expr& pre, const Expr& post, const Map>& node_map) { + auto x = node_map[x_][0]; + bool const_shape = true; + Array newshape; + for (auto dim : Downcast(pre->checked_type())->shape) { + if (dim.as() == nullptr) { + const_shape = false; + break; + } + newshape.push_back(Downcast(dim)); + } + if (const_shape) { + return MakeReshape(x, newshape); + } + return post; + } + + DFPattern pattern() const { return pattern_; } + + private: + /*! \brief Pattern input */ + DFPattern x_; + /*! \brief Pattern for consecutive reshape or reverse_reshape ops */ + DFPattern pattern_; + +}; + +/*! + * \brief ExprSimplifier simplifies the Relay expression. + */ +class ExprSimplifier { + public: + ExprSimplifier() { + auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) { + Expr pre = args[0]; + Expr post = args[1]; + Map> node_map = args[2]; + *rv = simplify_reshape_.callback(pre, post, node_map); + }; + callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), + true)); + } + + Expr Simplify(const Expr& expr) { + return RewritePatterns(callbacks_, expr); + } + + private: + /*! \brief Simplify reshape pattern */ + SimplifyReshape simplify_reshape_; + /*! \brief Callbacks for expr simplification */ + Array callbacks_; +}; + +Expr SimplifyExpr(const Expr& expr, const IRModule& module) { + return ExprSimplifier().Simplify(expr); +} + +namespace transform { + +Pass SimplifyExpr() { + runtime::TypedPackedFunc pass_func = + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyExpr(f, m)); + }; + return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr); + +} // namespace transform + +} // namespace relay +} // namespace tvm \ No newline at end of file diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index f390b720b80a..34a098731b86 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -599,6 +599,7 @@ def test_rewrite(): class TestRewrite(DFPatternCallback): def __init__(self): + super(TestRewrite, self).__init__() self.pattern = add_pattern def callback(self, pre, post, node_map): @@ -617,6 +618,7 @@ def test_rewrite_func(): class TestRewrite(DFPatternCallback): def __init__(self): + super(TestRewrite, self).__init__() self.pattern = add_pattern def callback(self, pre, post, node_map): @@ -634,6 +636,7 @@ def callback(self, pre, post, node_map): def test_nested_rewrite(): class PatternCallback(DFPatternCallback): def __init__(self, pattern): + super(PatternCallback, self).__init__() self.pattern = pattern def callback(self, pre, post, node_map): @@ -682,6 +685,7 @@ def test_not_fuse_multi_diamond(): class BatchnormCallback(DFPatternCallback): def __init__(self): + super(BatchnormCallback, self).__init__() self.x = wildcard() self.var = wildcard() self.mean = wildcard() @@ -798,6 +802,7 @@ def test_fuse_batchnorm_commutation(): def test_quadruple_rewrite_dominator(): class DominatorRemovalCallback(DFPatternCallback): def __init__(self): + super(DominatorRemovalCallback, self).__init__() self.inp = wildcard() self.weight = wildcard() is_conv2d = is_op('nn.conv2d')(self.inp, self.weight) @@ -860,31 +865,37 @@ def callback(self, pre, post, node_map): class AddCallback(ElwiseNullCallback): def __init__(self): + super(AddCallback, self).__init__() self.x = wildcard() self.pattern = self.x + zero class SubCallback(ElwiseNullCallback): def __init__(self): + super(SubCallback, self).__init__() self.x = wildcard() self.pattern = self.x - zero class MulCallback(ElwiseNullCallback): def __init__(self): + super(MulCallback, self).__init__() self.x = wildcard() self.pattern = self.x * one class DivCallback(ElwiseNullCallback): def __init__(self): + super(DivCallback, self).__init__() self.x = wildcard() self.pattern = self.x / one class MulZeroCallback(ElwiseNullCallback): def __init__(self): + super(MulZeroCallback, self).__init__() self.x = zero self.pattern = self.x * wildcard() class ZeroDivCallback(ElwiseNullCallback): def __init__(self): + super(ZeroDivCallback, self).__init__() self.x = zero self.pattern = self.x / wildcard() @@ -1265,6 +1276,7 @@ def test_match_match(): add_pattern = is_op('add')(wildcard(), wildcard()) class TestRewrite(DFPatternCallback): def __init__(self): + super(TestRewrite, self).__init__() self.pattern = add_pattern def callback(self, pre, post, node_map): return post.args[0] - post.args[1] diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py index 3cc0b379df72..e934c11a6370 100644 --- a/tests/python/relay/test_pass_simplify_expr.py +++ b/tests/python/relay/test_pass_simplify_expr.py @@ -19,7 +19,6 @@ from tvm.relay import transform from tvm.relay.testing import run_opt_pass - def test_simplify_reshape(): def before(): x = relay.var("x", shape=(1, 16, 16, 16), dtype="float32") @@ -37,11 +36,25 @@ def expected(): y = relay.reshape(y, newshape=(32, 16, 16)) return relay.Function([x, w], y) + def symbolic(): + b = tvm.te.size_var('b') + x = relay.var("x", shape=(b, 16, 16, 16), dtype="float32") + w = relay.var("w", shape=(32, 16, 3, 3), dtype="float32") + y = relay.nn.conv2d(x, w, padding=(1, 1)) + y = relay.reshape(y, newshape=(1, 16, -1)) + y = relay.reshape(y, newshape=(4, 8, -1, 16)) + y = relay.reverse_reshape(y, newshape=(32, 0, -1)) + return relay.Function([x, w], y) + z = before() zz = run_opt_pass(z, transform.SimplifyExpr()) after = run_opt_pass(expected(), transform.InferType()) assert tvm.ir.structural_equal(zz, after) + z = symbolic() + zz = run_opt_pass(z, transform.SimplifyExpr()) + after = run_opt_pass(symbolic(), transform.InferType()) + assert tvm.ir.structural_equal(zz, after) if __name__ == "__main__": test_simplify_reshape() From 7b35182090b0ee7a2c927eb0e639213b02ea0f1c Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 14 Jul 2020 14:16:31 -0700 Subject: [PATCH 03/13] lint --- src/relay/transforms/simplify_expr.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index c3af05ea759f..47d5b0c62ff8 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -72,7 +72,6 @@ class SimplifyReshape { DFPattern x_; /*! \brief Pattern for consecutive reshape or reverse_reshape ops */ DFPattern pattern_; - }; /*! @@ -121,4 +120,4 @@ TVM_REGISTER_GLOBAL("relay._transform.SimplifyExpr").set_body_typed(SimplifyExpr } // namespace transform } // namespace relay -} // namespace tvm \ No newline at end of file +} // namespace tvm From 842b5a724fbd4a5548d01dd16db6d25aa7ea4a60 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 14 Jul 2020 14:47:30 -0700 Subject: [PATCH 04/13] Add pass to optimization pipeline --- include/tvm/relay/transform.h | 7 +++++++ src/relay/backend/build_module.cc | 1 + src/relay/backend/vm/compiler.cc | 4 ++++ src/relay/transforms/simplify_expr.cc | 4 ++-- 4 files changed, 14 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 1b8b31aee5d1..d995301c1688 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -360,6 +360,13 @@ TVM_DLL Pass Inline(); */ TVM_DLL Pass RemoveUnusedFunctions(Array entry_functions); +/*! + * \brief Simplify the Relay expression. + * + * \return The pass. + */ +TVM_DLL Pass SimplifyExpr(); + } // namespace transform /*! diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index b589bcce99fc..b57c0eb8cbdb 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -276,6 +276,7 @@ class RelayBuildModule : public runtime::ModuleNode { } }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::CombineParallelDense(3)); pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index d01dbda24a4c..585b8033be8d 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -945,10 +945,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe *rv = false; }); pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); + pass_seqs.push_back(transform::SimplifyExpr()); pass_seqs.push_back(transform::InlinePrimitives()); pass_seqs.push_back(transform::CombineParallelConv2D(3)); pass_seqs.push_back(transform::CombineParallelDense(3)); + pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FoldScaleAxis()); pass_seqs.push_back(transform::CanonicalizeCast()); @@ -959,6 +961,8 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::AlterOpLayout()); } + // Fast math optimizations. + pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); pass_seqs.push_back(transform::FuseOps()); diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index 47d5b0c62ff8..f16ec597799e 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -101,7 +101,7 @@ class ExprSimplifier { Array callbacks_; }; -Expr SimplifyExpr(const Expr& expr, const IRModule& module) { +Expr SimplifyExpr(const Expr& expr) { return ExprSimplifier().Simplify(expr); } @@ -110,7 +110,7 @@ namespace transform { Pass SimplifyExpr() { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { - return Downcast(SimplifyExpr(f, m)); + return Downcast(SimplifyExpr(f)); }; return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); } From b58b998f21b5e3204974e999a8eb2e0982a7d762 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 14 Jul 2020 14:57:59 -0700 Subject: [PATCH 05/13] lint --- src/relay/transforms/simplify_expr.cc | 19 +++++++------------ 1 file changed, 7 insertions(+), 12 deletions(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index f16ec597799e..c0b9e94a670d 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -22,11 +22,12 @@ * \brief A pass for simplifying the Relay expression. */ +#include #include #include #include -#include #include + #include "../op/tensor/transform.h" namespace tvm { @@ -86,13 +87,11 @@ class ExprSimplifier { Map> node_map = args[2]; *rv = simplify_reshape_.callback(pre, post, node_map); }; - callbacks_.push_back(DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), - true)); + callbacks_.push_back( + DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true)); } - Expr Simplify(const Expr& expr) { - return RewritePatterns(callbacks_, expr); - } + Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr); } private: /*! \brief Simplify reshape pattern */ @@ -101,17 +100,13 @@ class ExprSimplifier { Array callbacks_; }; -Expr SimplifyExpr(const Expr& expr) { - return ExprSimplifier().Simplify(expr); -} +Expr SimplifyExpr(const Expr& expr) { return ExprSimplifier().Simplify(expr); } namespace transform { Pass SimplifyExpr() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(SimplifyExpr(f)); - }; + [=](Function f, IRModule m, PassContext pc) { return Downcast(SimplifyExpr(f)); }; return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); } From 43a172a69fbdfe755cb571fb1a15a6f905a2bca2 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Tue, 14 Jul 2020 15:01:19 -0700 Subject: [PATCH 06/13] comments --- python/tvm/relay/dataflow_pattern/__init__.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 37a924427b39..0429a78aec5d 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -755,15 +755,11 @@ def rewrite(callbacks, expr: Expr) -> Expr: result : tvm.relay.Expr The Expression with matched subgraphs rewritten by the callbacks. """ - if isinstance(callbacks, DFPatternCallback): - assert callbacks.pattern is not None - tmp = [_DFPatternCallback(callbacks.pattern, callbacks.callback, callbacks.require_type)] - else: - tmp = [] - for callback in callbacks: - assert callback.pattern is not None - tmp.append(_DFPatternCallback(callback.pattern, callback.callback, - callback.require_type)) + callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks + tmp = [] + for callback in callbacks: + assert callback.pattern is not None + tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type)) return ffi.rewrite(tmp, expr) From f28a13fa192bd4ea97462f22ef8bf78e7074f5fc Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 10:33:00 -0700 Subject: [PATCH 07/13] fix bug --- include/tvm/relay/dataflow_matcher.h | 3 +- python/tvm/relay/dataflow_pattern/__init__.py | 9 +++- src/relay/ir/dataflow_matcher.cc | 45 ++++++++++++++++--- src/relay/transforms/simplify_expr.cc | 13 ++++-- 4 files changed, 57 insertions(+), 13 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 8b72836b3334..89154a1db7fc 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -82,11 +82,12 @@ bool MatchPattern(DFPattern pattern, Expr expr); * * \param callbacks An array of DFPatternCallback Nodes * \param expr The expression to rewrite + * \param mod The module that associates with the expr * * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the * functions inside the callbacks */ -Expr RewritePatterns(Array callbacks, Expr expr); +Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod); /*! * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 0429a78aec5d..51c7addd4305 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -19,6 +19,7 @@ from typing import Callable, Dict, List, Optional import tvm._ffi +from tvm import IRModule from tvm.relay.expr import RelayExpr as Expr from ... import _ffi as tvm_ffi @@ -739,7 +740,7 @@ def __init__(self, pattern, callback, require_type): self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) -def rewrite(callbacks, expr: Expr) -> Expr: +def rewrite(callbacks, expr: Expr, mod: Optional[IRModule] = None) -> Expr: """ Rewrite expression with the given callbacks. @@ -749,19 +750,23 @@ def rewrite(callbacks, expr: Expr) -> Expr: The input callback or list of callbacks. expr : tvm.relay.Expr The expression to rewrite. + mod : Optional[tvm.IRModule] + The module that associates with the expression. Returns ------- result : tvm.relay.Expr The Expression with matched subgraphs rewritten by the callbacks. """ + if mod is None: + mod = IRModule.from_expr(expr) callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks tmp = [] for callback in callbacks: assert callback.pattern is not None tmp.append(_DFPatternCallback(callback.pattern, callback.callback, callback.require_type)) - return ffi.rewrite(tmp, expr) + return ffi.rewrite(tmp, expr, mod) def partition(pattern: "DFPattern", diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index d16172e6fbdc..14aeede53a93 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -390,6 +390,34 @@ Expr InferType(const Expr& expr) { } } +Expr InferTypeWithModule(const Expr& expr, const IRModule& m) { + IRModule mod(m->functions, m->type_definitions, m->Imports()); + int idx = 0; + std::string gv_name; + do { + std::ostringstream oss; + oss << "_tmp" << idx; + gv_name = oss.str(); + ++idx; + } while (mod->ContainGlobalVar(gv_name)); + GlobalVar gvar(gv_name); + BaseFunc func; + if (expr.as()) { + func = Downcast(expr); + } else { + func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {}); + } + mod->Add(gvar, func); + mod = transform::InferType()(mod); + Expr ret; + if (expr.as()) { + ret = mod->Lookup(gvar); + } else { + ret = mod->Lookup(gvar).as()->body; + } + return ret; +} + bool DFPatternMatcher::VisitDFPattern_(const TypePatternNode* op, const Expr& expr) { auto expr_type = InferType(expr).as()->checked_type(); return (StructuralEqual()(op->type, expr_type)) && VisitDFPattern(op->pattern, expr); @@ -727,7 +755,7 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DFPatternCallback") */ class PatternRewriter : protected MixedModeMutator { public: - PatternRewriter() {} + PatternRewriter(IRModule mod) : mod_(mod) {} /*! \brief Rewrite can take a number of callbacks and will repeatedly rewrite the graph with the * callbacks until it stops changing */ Expr Rewrite(const Array& callbacks, const Expr& pre) { @@ -735,12 +763,15 @@ class PatternRewriter : protected MixedModeMutator { auto last = post; // rewrite the graph until it stops changing to make sure all rewrites are complete int count = 0; + bool changed = false; + static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); + CHECK(structural_equal) << "node.StructuralEqual is not registered."; do { last = post; for (auto callback : callbacks) { callback_ = callback; if (callback_->require_type) { - post = InferType(post); + post = InferTypeWithModule(post, mod_); } auto grouper = PatternGrouper(); groups_ = grouper.GroupMatches(callback_->pattern, post); @@ -749,9 +780,10 @@ class PatternRewriter : protected MixedModeMutator { post = this->VisitExpr(post); count++; } - } while (!StructuralEqual()(last, post) || count >= 100); + changed = (*structural_equal)(last, post, false, true); + } while (!changed || count < 100); if (count >= 100) { - throw("Observed 100 rewrite passes, possible conflicting passes?"); + LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } return post; } @@ -776,13 +808,14 @@ class PatternRewriter : protected MixedModeMutator { return post; } + IRModule mod_; DFPatternCallback callback_; std::unordered_map groups_; std::unordered_map gid_assignments_; }; -Expr RewritePatterns(Array callbacks, Expr expr) { - return PatternRewriter().Rewrite(callbacks, expr); +Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod) { + return PatternRewriter(mod).Rewrite(callbacks, expr); } TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatterns); diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index c0b9e94a670d..f0df91e455b1 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -80,7 +80,7 @@ class SimplifyReshape { */ class ExprSimplifier { public: - ExprSimplifier() { + ExprSimplifier(IRModule mod) : mod_(mod) { auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) { Expr pre = args[0]; Expr post = args[1]; @@ -91,22 +91,27 @@ class ExprSimplifier { DFPatternCallback(simplify_reshape_.pattern(), PackedFunc(reshape_func), true)); } - Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr); } + Expr Simplify(const Expr& expr) { return RewritePatterns(callbacks_, expr, mod_); } private: + IRModule mod_; /*! \brief Simplify reshape pattern */ SimplifyReshape simplify_reshape_; /*! \brief Callbacks for expr simplification */ Array callbacks_; }; -Expr SimplifyExpr(const Expr& expr) { return ExprSimplifier().Simplify(expr); } +Expr SimplifyExpr(const Expr& expr, const IRModule& mod) { + return ExprSimplifier(mod).Simplify(expr); +} namespace transform { Pass SimplifyExpr() { runtime::TypedPackedFunc pass_func = - [=](Function f, IRModule m, PassContext pc) { return Downcast(SimplifyExpr(f)); }; + [=](Function f, IRModule m, PassContext pc) { + return Downcast(SimplifyExpr(f, m)); + }; return CreateFunctionPass(pass_func, 0, "SimplifyExpr", {"InferType"}); } From d5c1efd4097624b998a161ff45a391d6f6dd2ae2 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 10:37:21 -0700 Subject: [PATCH 08/13] x --- src/relay/transforms/simplify_expr.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc index f0df91e455b1..079b86715a48 100644 --- a/src/relay/transforms/simplify_expr.cc +++ b/src/relay/transforms/simplify_expr.cc @@ -80,7 +80,7 @@ class SimplifyReshape { */ class ExprSimplifier { public: - ExprSimplifier(IRModule mod) : mod_(mod) { + explicit ExprSimplifier(IRModule mod) : mod_(mod) { auto reshape_func = [this](TVMArgs args, TVMRetValue* rv) { Expr pre = args[0]; Expr post = args[1]; From 7bd940146a62d8757f55f322e02c281d9d0497ba Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 19:24:27 +0000 Subject: [PATCH 09/13] fix --- src/relay/ir/dataflow_matcher.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 14aeede53a93..50c05f2923bc 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -763,7 +763,7 @@ class PatternRewriter : protected MixedModeMutator { auto last = post; // rewrite the graph until it stops changing to make sure all rewrites are complete int count = 0; - bool changed = false; + bool equal = true; static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual"); CHECK(structural_equal) << "node.StructuralEqual is not registered."; do { @@ -780,8 +780,8 @@ class PatternRewriter : protected MixedModeMutator { post = this->VisitExpr(post); count++; } - changed = (*structural_equal)(last, post, false, true); - } while (!changed || count < 100); + equal = (*structural_equal)(last, post, false, true); + } while (!equal && count < 100); if (count >= 100) { LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?"; } From 8479dd9d1f0ad8e08f136718cff52fbf414e67db Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 20:13:55 +0000 Subject: [PATCH 10/13] fix --- include/tvm/relay/dataflow_matcher.h | 2 +- python/tvm/relay/dataflow_pattern/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 89154a1db7fc..6639c600e2a1 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -87,7 +87,7 @@ bool MatchPattern(DFPattern pattern, Expr expr); * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the * functions inside the callbacks */ -Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod); +Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod=IRModule()); /*! * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 51c7addd4305..11167d11a849 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -759,7 +759,7 @@ def rewrite(callbacks, expr: Expr, mod: Optional[IRModule] = None) -> Expr: The Expression with matched subgraphs rewritten by the callbacks. """ if mod is None: - mod = IRModule.from_expr(expr) + mod = IRModule() callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks tmp = [] for callback in callbacks: From e0add95fc3922e38041a99bd02d53025d6a10196 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 20:16:49 +0000 Subject: [PATCH 11/13] lint --- include/tvm/relay/dataflow_matcher.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 6639c600e2a1..12e4e3f45fef 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -87,7 +87,7 @@ bool MatchPattern(DFPattern pattern, Expr expr); * \return Return An Expr with every match of the pattern inside the callbacks rewritten by the * functions inside the callbacks */ -Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod=IRModule()); +Expr RewritePatterns(Array callbacks, Expr expr, IRModule mod = IRModule()); /*! * \brief Partition all matches of a DFPattern inside an Expr into separate Function calls From 45b2dcc253fcebfa74c8b0be97ed00c884cb7efb Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 21:25:22 +0000 Subject: [PATCH 12/13] fix warning in doc --- python/tvm/relay/dataflow_pattern/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 11167d11a849..65476f966fea 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -19,11 +19,10 @@ from typing import Callable, Dict, List, Optional import tvm._ffi -from tvm import IRModule from tvm.relay.expr import RelayExpr as Expr from ... import _ffi as tvm_ffi -from ...ir import make_node +from ...ir import make_node, IRModule from ...ir.base import Node from ...runtime import Object from ..op import get @@ -750,7 +749,7 @@ def rewrite(callbacks, expr: Expr, mod: Optional[IRModule] = None) -> Expr: The input callback or list of callbacks. expr : tvm.relay.Expr The expression to rewrite. - mod : Optional[tvm.IRModule] + mod : Optional[tvm.ir.IRModule] The module that associates with the expression. Returns From 4ace1cad1bd57b5e7647089ef53232da709ed385 Mon Sep 17 00:00:00 2001 From: Haichen Shen Date: Wed, 15 Jul 2020 22:18:24 +0000 Subject: [PATCH 13/13] x --- python/tvm/relay/dataflow_pattern/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 65476f966fea..03bdd1952fa1 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -22,7 +22,8 @@ from tvm.relay.expr import RelayExpr as Expr from ... import _ffi as tvm_ffi -from ...ir import make_node, IRModule +from ... import ir as _ir +from ...ir import make_node from ...ir.base import Node from ...runtime import Object from ..op import get @@ -739,7 +740,7 @@ def __init__(self, pattern, callback, require_type): self.__init_handle_by_constructor__(ffi.DFPatternCallback, pattern, callback, require_type) -def rewrite(callbacks, expr: Expr, mod: Optional[IRModule] = None) -> Expr: +def rewrite(callbacks, expr: Expr, mod: Optional[_ir.IRModule] = None) -> Expr: """ Rewrite expression with the given callbacks. @@ -758,7 +759,7 @@ def rewrite(callbacks, expr: Expr, mod: Optional[IRModule] = None) -> Expr: The Expression with matched subgraphs rewritten by the callbacks. """ if mod is None: - mod = IRModule() + mod = _ir.IRModule() callbacks = [callbacks] if isinstance(callbacks, DFPatternCallback) else callbacks tmp = [] for callback in callbacks: