diff --git a/include/tvm/relay/dataflow_matcher.h b/include/tvm/relay/dataflow_matcher.h index 58aa6400b650..517582bda01d 100644 --- a/include/tvm/relay/dataflow_matcher.h +++ b/include/tvm/relay/dataflow_matcher.h @@ -27,6 +27,7 @@ #include #include +#include #include #include @@ -87,10 +88,14 @@ Expr RewritePatterns(Array callbacks, Expr expr); * * \param pattern The pattern to match * \param expr The expression to patition + * \param attrs A set of parameter names and values to apply to the partitioned function + * \param check A callback function for checking more complicated properties of the matched + * expressions, returns true if the match is accepted and false otherwise * * \return Return the paritioned Expr. */ -Expr PartitionPattern(DFPattern pattern, Expr expr); +Expr PartitionPattern(DFPattern pattern, Expr expr, Map attrs, + PackedFunc check); } // namespace relay } // namespace tvm diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index 2582894c14eb..f8be3e2b28ce 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -109,7 +109,7 @@ def match(self, expr: Expr) -> bool: """ return match(self, expr) - def partition(self, expr: Expr, attrs=None) -> Expr: + def partition(self, expr: Expr, attrs=None, check=lambda x: True) -> Expr: """ Parition the expression into functions defined by this pattern @@ -119,13 +119,16 @@ def partition(self, expr: Expr, attrs=None) -> Expr: The expression to match. attrs : Optional[Dict[str, Object]] A dictionary of Attribute name/values to add to the paritioned function + check : Function + A function to perform more complicated checks on the matched expression. + Returns true if partitioning should proceed, false otherwise. Returns ------- result : tvm.relay.Expr The Expression with matched subgraphs replaced by function calls to that subgraph """ - return partition(self, expr, attrs) + return partition(self, expr, attrs, check) def dominates(self, parent, path=None): """ @@ -561,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr: return ffi.rewrite(tmp, expr) -def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr: +def partition(pattern: DFPattern, expr: Expr, attrs=None, check=lambda x: True) -> Expr: """ Parition the expression into a series of functions that match the pattern @@ -571,12 +574,15 @@ def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr: The pattern to match expr : tvm.relay.Expr The expression to split into functions - expr : Optional[Dict[str, Object]] + attrs : Optional[Dict[str, Object]] A dict of attributes to apply to the partitioned function + check : Function + A function to perform more complicated checks on the matched expression. + Returns true if partitioning should proceed, false otherwise. Returns ------- result : tvm.relay.Expr The Expression with matched subgraphs replaced by function calls to that subgraph """ - return ffi.partition(pattern, expr, attrs) + return ffi.partition(pattern, expr, attrs, check) diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 4bb2b0ba7249..980935c34c11 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -693,11 +693,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt class PatternPartitioner : protected MixedModeMutator { public: Expr Partition(const DFPattern& pattern, const Expr& pre, - const Map& attrs) { + const Map& attrs, PackedFunc check) { auto grouper = PatternGrouper(); groups_ = grouper.GroupMatches(pattern, pre); gid_assignments_ = grouper.GetGIDAssignments(); attrs_ = attrs; + check_ = check; return this->VisitExpr(pre); } @@ -718,7 +719,8 @@ class PatternPartitioner : protected MixedModeMutator { Expr DispatchVisitExpr(const Expr& pre) override { auto post = MixedModeMutator::DispatchVisitExpr(pre); - if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node) { + if (gid_assignments_.count(pre) && pre == groups_[gid_assignments_[pre]].root_node && + static_cast(check_(pre))) { post = RewritePartition(groups_[gid_assignments_[pre]]); } return post; @@ -727,16 +729,17 @@ class PatternPartitioner : protected MixedModeMutator { Map attrs_; std::vector groups_; std::unordered_map gid_assignments_; + PackedFunc check_; }; -Expr PartitionPattern(DFPattern pattern, Expr expr, Map attrs) { - return PatternPartitioner().Partition(pattern, expr, attrs); +Expr PartitionPattern(DFPattern pattern, Expr expr, Map attrs, + PackedFunc check) { + return PatternPartitioner().Partition(pattern, expr, attrs, check); } TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition") - .set_body_typed([](DFPattern pattern, Expr expr, Map attrs) { - return PartitionPattern(pattern, expr, attrs); - }); + .set_body_typed([](DFPattern pattern, Expr expr, Map attrs, + PackedFunc check) { return PartitionPattern(pattern, expr, attrs, check); }); } // namespace relay } // namespace tvm diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 411ef0f49265..3a605e4da94b 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -17,6 +17,7 @@ import tvm from tvm import relay from tvm.relay.dataflow_pattern import * +from tvm.relay.testing import run_opt_pass import numpy as np # NB: 1 corresponds to the C++ enum that specicfies this @@ -880,7 +881,7 @@ def nested_diamond(inp, weight): def get_BN(x, var, mean, beta, gamma, eps = 1e-5): return gamma * (x - mean)/relay.op.sqrt(var + relay.const(eps)) + beta -def test_parition_batchnorm(): +def test_partition_batchnorm(): x = relay.var('x') var = relay.var('var') mean = relay.var('mean') @@ -900,7 +901,7 @@ def test_parition_batchnorm(): partitioned = BatchnormCallback().pattern.partition(BN) assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta)) -def test_parition_double_batchnorm(): +def test_partition_double_batchnorm(): x = relay.var('x') var = relay.var('var') mean = relay.var('mean') @@ -916,7 +917,7 @@ def test_parition_double_batchnorm(): betaf = relay.var('betaf') gammaf = relay.var('gammaf') f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern","subtract_multiply_add_sqrt_divide_add_") - # The paritioner doesn't replace duplicates, so we use two copies of the function + # The partitioner doesn't replace duplicates, so we use two copies of the function xf2 = relay.var('xf2') varf2 = relay.var('varf2') meanf2 = relay.var('meanf2') @@ -928,6 +929,58 @@ def test_parition_double_batchnorm(): reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta) assert tvm.ir.structural_equal(partitioned, reference) +def test_partition_check(): + pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + def check(pre): + return pre.args[0].attrs.data_layout == "NCHW" + + x = relay.var('input') + w = relay.var('weight') + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + + xf = relay.var('input') + wf = relay.var('weight') + conv2df = relay.op.nn.conv2d(xf, wf) + reluf = relay.op.nn.relu(conv2df) + func = relay.Function([xf, wf], reluf).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_") + + reference = func(x, w) + partitioned = pattern.partition(relu, check=check) + assert tvm.ir.structural_equal(partitioned, reference) + + conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC") + relu = relay.op.nn.relu(conv2d) + assert relu == pattern.partition(relu, check=check) + +def test_partition_check_types(): + pattern = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard())) + def check(pre): + conv = pre.args[0] + return (conv.attrs.data_layout == "NCHW") and bool(conv.checked_type.shape[0] == 1) + + x = relay.var('input', shape=(1, 10, 10, 10)) + w = relay.var('weight', shape=(10, 10, 3, 3)) + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + relu = run_opt_pass(relu, relay.transform.InferType()) + + partitioned = pattern.partition(relu, check=check) + assert partitioned.op.attrs["PartitionedFromPattern"] == "nn.conv2d_nn.relu_" + + conv2d = relay.op.nn.conv2d(x, w, data_layout="NHWC") + relu = relay.op.nn.relu(conv2d) + relu = run_opt_pass(relu, relay.transform.InferType()) + assert relu == pattern.partition(relu, check=check) + + x = relay.var('input', shape=(2, 10, 10, 10)) + w = relay.var('weight', shape=(10, 10, 3, 3)) + conv2d = relay.op.nn.conv2d(x, w) + relu = relay.op.nn.relu(conv2d) + relu = run_opt_pass(relu, relay.transform.InferType()) + assert relu == pattern.partition(relu, check=check) + + if __name__ == "__main__": test_match_op() test_no_match_op() @@ -957,6 +1010,8 @@ def test_parition_double_batchnorm(): test_algebraic_simplify() test_partition_dominator() test_quadruple_partition_dominator() - test_parition_batchnorm() - test_parition_double_batchnorm() + test_partition_batchnorm() + test_partition_double_batchnorm() + test_partition_check() + test_partition_check_types()