From 9bcf0bc107b6aacf901617a74595b34c0fa7c0df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=94=B5=E7=BA=BF=E6=9D=86?= <446100240@qq.com> Date: Sat, 5 Aug 2023 06:27:52 +0800 Subject: [PATCH] [Relay] add redirecting operation to dataflow pattern graph (#15392) * Add redirecting operation to dataflow pattern graph * Lint --- include/tvm/relay/dataflow_pattern.h | 6 ++ python/tvm/relay/dataflow_pattern/__init__.py | 13 +++++ src/relay/ir/dataflow_matcher.cc | 6 +- src/relay/ir/dataflow_pattern.cc | 10 ++++ src/relay/ir/dataflow_pattern_functor.cc | 6 +- src/relay/ir/indexed_graph.cc | 7 ++- tests/python/relay/test_dataflow_pattern.py | 56 +++++++++++++++++++ 7 files changed, 101 insertions(+), 3 deletions(-) diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 8c30a0df9fae..040372db3533 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -362,6 +362,10 @@ class WildcardPatternNode : public DFPatternNode { public: void VisitAttrs(tvm::AttrVisitor* v) {} + /*! \brief If the wildcard is redirected, then pattern is not nullptr, and the wildcard + * redirects to the pattern. */ + Optional pattern{nullptr}; + static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern"; TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode); }; @@ -372,6 +376,8 @@ class WildcardPatternNode : public DFPatternNode { class WildcardPattern : public DFPattern { public: TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode); + + void redirect_to(DFPattern pat) const; }; class TypePattern; diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 96950a2e4749..76a24c048cf9 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -722,6 +722,19 @@ class WildcardPattern(DFPattern): def __init__(self): self.__init_handle_by_constructor__(ffi.WildcardPattern) + def redirect_to( + self, + pat: "DFPattern", + ): + """Redirect the WildcardPattern to another pattern + + Parameters + ---------- + pat: relay.dataflow_pattern.DFPattern + The pattern that wildcard is redirected to. + """ + ffi.WildcardPattern_redirect_to(self, pat) + @register_df_node class TypePattern(DFPattern): diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 249f4ccf7a44..ee585446cb26 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -488,7 +488,11 @@ bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr } bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) { - return true; + if (op->pattern) { + return VisitDFPattern(op->pattern.value(), expr); + } else { + return true; + } } bool MatchPattern(DFPattern pattern, Expr expr) { diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index c141ca51ef4f..637cb0665d38 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -344,8 +344,18 @@ TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable) << ")"; }); +void WildcardPattern::redirect_to(DFPattern pat) const { + WildcardPatternNode* ptr = static_cast(get_mutable()); + ptr->pattern = pat; +} + TVM_REGISTER_NODE_TYPE(WildcardPatternNode); +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern_redirect_to") + .set_body_typed([](WildcardPattern wildcard, DFPattern pat) { + return wildcard.redirect_to(pat); + }); + TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() { auto w = WildcardPattern(make_object()); return w; diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index 290f72df1deb..76b3fe068e45 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -105,7 +105,11 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {} void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {} -void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) { + if (op->pattern) { + VisitDFPattern(op->pattern.value()); + } +} } // namespace relay } // namespace tvm diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 044884f87eb4..f10920769d1f 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -537,7 +537,12 @@ std::unique_ptr> CreateIndexedGraph(const DFPattern& pat void VisitDFPattern_(const VarPatternNode* op) override {} - void VisitDFPattern_(const WildcardPatternNode* op) override {} + void VisitDFPattern_(const WildcardPatternNode* op) override { + if (op->pattern) { + auto node = graph_->item_to_node(GetRef(op)); + AddOutput(op->pattern.value(), node); + } + } std::unique_ptr> graph_; }; diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index c4a83735cee9..3950c02c08a4 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -1995,5 +1995,61 @@ def test_partition_parallel_branch_with_same_input(): assert tvm.ir.structural_equal(partitioned, reference) +def test_rewrite_with_pattern_recursion(): + data = relay.var("data", relay.TensorType((2, 8), "float32")) + dense_weight = relay.const(np.zeros((4, 8))) + feat = relay.nn.dense(data, dense_weight) + feat = relay.cast(feat, "float32") + feat = relay.cast(feat, "float32") + feat = relay.cast(feat, "float32") + feat = relay.cast(feat, "float32") + feat = relay.cast(feat, "float32") + oup = relay.cast(feat, "float32") + + expected = relay.nn.relu(oup) + + class TheRewrite(DFPatternCallback): + def __init__(self, pattern): + super(TheRewrite, self).__init__(rewrite_once=True) + self.pattern = pattern + + def callback(self, pre, post, node_map): + return relay.nn.relu(post) + + def test_reset_call_args(): + dense_pattern = is_op("nn.dense")(wildcard(), wildcard()) + wildcard_redirect = wildcard() + the_pattern = is_op("cast")(wildcard_redirect) + the_pattern2 = the_pattern | dense_pattern + wildcard_redirect.redirect_to(the_pattern2) + + actual = rewrite(TheRewrite(the_pattern), oup) + tvm.ir.assert_structural_equal(actual, expected) + + def test_reset_alt_left(): + dense_pattern = is_op("nn.dense")(wildcard(), wildcard()) + wildcard_redirect = wildcard() + or_pattern = wildcard_redirect | dense_pattern + the_pattern = is_op("cast")(or_pattern) + wildcard_redirect.redirect_to(the_pattern) + + actual = rewrite(TheRewrite(the_pattern), oup) + tvm.ir.assert_structural_equal(actual, expected) + + def test_reset_alt_right(): + dense_pattern = is_op("nn.dense")(wildcard(), wildcard()) + wildcard_redirect = wildcard() + or_pattern = dense_pattern | wildcard_redirect + the_pattern = is_op("cast")(or_pattern) + wildcard_redirect.redirect_to(the_pattern) + + actual = rewrite(TheRewrite(the_pattern), oup) + tvm.ir.assert_structural_equal(actual, expected) + + test_reset_call_args() + test_reset_alt_left() + test_reset_alt_right() + + if __name__ == "__main__": tvm.testing.main()