From cf2a83b669a535bc94ab0db6a63502823787cb5a Mon Sep 17 00:00:00 2001
From: Matthew Brookhart
Date: Wed, 20 May 2020 17:28:43 -0700
Subject: [PATCH] Label Pattern Partitions (#5627)

* Label pattern partitions with a default label, to prevent nested
  partitions, and with an optional user-supplied label

* Add node names in topological order to the Partitioned attribute

* respond to review comments

* move partition tag into const in attr namespace
---
 include/tvm/relay/function.h                  |  2 +
 python/tvm/relay/dataflow_pattern/__init__.py | 12 ++-
 src/relay/ir/dataflow_matcher.cc              | 73 +++++++++++++++++--
 tests/python/relay/test_dataflow_pattern.py   | 55 ++++++++++++--
 4 files changed, 124 insertions(+), 18 deletions(-)

diff --git a/include/tvm/relay/function.h b/include/tvm/relay/function.h
index ab9111bfe084..d52a66cdadeb 100644
--- a/include/tvm/relay/function.h
+++ b/include/tvm/relay/function.h
@@ -141,6 +141,8 @@ constexpr const char* kSkipOptimization = "SkipOptimization";
 constexpr const char* kComposite = "Composite";
 /*! \brief Mark the function to be inlined. */
 constexpr const char* kInline = "Inline";
+/*! \brief Indicate the function was created by the Pattern Partitioning Pass. */
+constexpr const char* kPartitionedFromPattern = "PartitionedFromPattern";
 }  // namespace attr
 }  // namespace relay
 
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index ca324bc444ec..54fe80a69d3f 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -112,7 +112,7 @@ def match(self, expr: Expr) -> bool:
         """
         return match(self, expr)
 
-    def partition(self, expr: Expr) -> bool:
+    def partition(self, expr: Expr, attrs=None) -> Expr:
         """
         Partition the expression into functions defined by this pattern
 
@@ -120,13 +120,15 @@ def partition(self, expr: Expr) -> bool:
         ----------
         expr : tvm.relay.Expr
             The expression to match.
+        attrs : Optional[Dict[str, Object]]
+            A dictionary of attribute name/value pairs to add to the partitioned function
 
         Returns
         -------
         result : tvm.relay.Expr
             The Expression with matched subgraphs replaced by function calls to that subgraph
         """
-        return partition(self, expr)
+        return partition(self, expr, attrs)
 
     def dominates(self, parent, path=None):
         """
@@ -562,7 +564,7 @@ def rewrite(callbacks, expr: Expr) -> Expr:
     return ffi.rewrite(tmp, expr)
 
-def partition(pattern: DFPattern, expr: Expr) -> Expr:
+def partition(pattern: DFPattern, expr: Expr, attrs=None) -> Expr:
     """
     Partition the expression into a series of functions that match the pattern
 
@@ -572,10 +574,12 @@ def partition(pattern: DFPattern, expr: Expr) -> Expr:
         The pattern to match
     expr : tvm.relay.Expr
         The expression to split into functions
+    attrs : Optional[Dict[str, Object]]
+        A dict of attributes to apply to the partitioned function
 
     Returns
     -------
     result : tvm.relay.Expr
         The Expression with matched subgraphs replaced by function calls to that subgraph
     """
-    return ffi.partition(pattern, expr)
+    return ffi.partition(pattern, expr, attrs)
 
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 7c70f324ebf3..0cd3bf709d60 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -386,6 +386,7 @@ class PatternGrouper : protected MixedModeVisitor {
     Expr root_node;
     int gid;
     Map<DFPattern, Array<Expr>> matched_nodes;
+    std::string name;
     Function function;
     Array<Expr> args;
   };
@@ -409,12 +410,17 @@ class PatternGrouper : protected MixedModeVisitor {
   }
 
  protected:
+  using ExprVisitor::VisitExpr_;
   void VisitLeaf(const Expr& pre) override {
     if (matcher_->Match(pattern_, pre)) {
       CreateGroup(pre);
     }
  }
-
+  void VisitExpr_(const FunctionNode* op) override {
+    if (op->attrs->dict.count(attr::kPartitionedFromPattern) == 0) {
+      ExprVisitor::VisitExpr_(op);
+    }
+  }
   /* \brief Creates a new set of nodes based on Group inputs, used to create functions and perform
    * group overlap analysis */
   class MatchExtractor : public ExprMutator {
@@ -422,6 +428,7 @@ class PatternGrouper : protected MixedModeVisitor {
     explicit MatchExtractor(const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual>& inputs)
         : inputs_(inputs) {}
     const std::unordered_map<Expr, Expr, ObjectHash, ObjectEqual>& GetMemo() { return this->memo_; }
+    const std::string& GetName() { return name_; }
 
    protected:
     Expr VisitExpr(const Expr& pre) override {
@@ -430,6 +437,46 @@ class PatternGrouper : protected MixedModeVisitor {
       }
       return ExprMutator::VisitExpr(pre);
     }
+    Expr VisitExpr_(const TupleNode* op) override {
+      auto out = ExprMutator::VisitExpr_(op);
+      name_ += "Tuple_";
+      return out;
+    }
+    Expr VisitExpr_(const FunctionNode* op) override {
+      auto out = ExprMutator::VisitExpr_(op);
+      name_ += "Function";
+      return out;
+    }
+    Expr VisitExpr_(const CallNode* call_node) override {
+      auto out = ExprMutator::VisitExpr_(call_node);
+      if (auto operation = call_node->op.as<OpNode>()) {
+        name_ += operation->name + "_";
+      } else {
+        name_ += "Call_";
+      }
+      return out;
+    }
+    Expr VisitExpr_(const LetNode* op) override {
+      auto out = ExprMutator::VisitExpr_(op);
+      name_ += "Let_";
+      return out;
+    }
+    Expr VisitExpr_(const IfNode* op) override {
+      auto out = ExprMutator::VisitExpr_(op);
+      name_ += "If_";
+      return out;
+    }
+    Expr VisitExpr_(const TupleGetItemNode* op) override {
+      auto out = ExprMutator::VisitExpr_(op);
+      name_ += "TupleGetItem" + std::to_string(op->index) + "_";
+      return out;
+    }
+    Expr VisitExpr_(const MatchNode* op) override {
+      auto out = ExprMutator::VisitExpr_(op);
+      name_ += "Match_";
+      return out;
+    }
+    std::string name_;
    const std::unordered_map<Expr, Var, ObjectHash, ObjectEqual> inputs_;
   };
 
@@ -487,7 +534,7 @@ class PatternGrouper : protected MixedModeVisitor {
     // Verify the pattern still holds
     CHECK(DFPatternMatcher(body).Match(pattern_, body));
     group.function = Function(params, body, NullValue<Type>(), Array<TypeVar>());
-
+    group.name = extractor.GetName();
     // Check to make sure we aren't overlapping with another group
     // The MatchExtractor will create a new graph by replacing nodes that match the inputs of the
     // pattern with the input FunctionVar* Variables. The resulting memoization map will only
@@ -612,10 +659,12 @@ TVM_REGISTER_GLOBAL("relay.dataflow_pattern.rewrite").set_body_typed(RewritePatt
  */
 class PatternPartitioner : protected MixedModeMutator {
  public:
-  Expr Partition(const DFPattern& pattern, const Expr& pre) {
+  Expr Partition(const DFPattern& pattern, const Expr& pre,
+                 const Map<String, ObjectRef>& attrs) {
     auto grouper = PatternGrouper();
     groups_ = grouper.GroupMatches(pattern, pre);
     gid_assignments_ = grouper.GetGIDAssignments();
+    attrs_ = attrs;
     return this->VisitExpr(pre);
   }
 
@@ -625,7 +674,13 @@ class PatternPartitioner : protected MixedModeMutator {
     for (size_t i = 0; i < group.args.size(); ++i) {
       args.push_back(memo_[group.args[i]]);
     }
-    return Call(group.function, args);
+    Function func = WithAttr(group.function, attr::kPartitionedFromPattern, String(group.name));
+    if (!attrs_.empty()) {
+      for (auto kv : attrs_) {
+        func = WithAttr(std::move(func), kv.first, kv.second);
+      }
+    }
+    return Call(func, args);
   }
 
   Expr DispatchVisitExpr(const Expr& pre) override {
@@ -636,15 +691,19 @@ class PatternPartitioner : protected MixedModeMutator {
     return post;
   }
 
+  Map<String, ObjectRef> attrs_;
   std::vector<Group> groups_;
   std::unordered_map<Expr, int, ObjectHash, ObjectEqual> gid_assignments_;
 };
 
-Expr PartitionPattern(DFPattern pattern, Expr expr) {
-  return PatternPartitioner().Partition(pattern, expr);
+Expr PartitionPattern(DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs) {
+  return PatternPartitioner().Partition(pattern, expr, attrs);
 }
 
-TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition").set_body_typed(PartitionPattern);
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.partition")
+    .set_body_typed([](DFPattern pattern, Expr expr, Map<String, ObjectRef> attrs) {
+      return PartitionPattern(pattern, expr, attrs);
+    });
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index 41b3d6d997e9..4f3560c85f3e 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -702,6 +702,41 @@ def test_algebraic_simplify():
     assert tvm.ir.structural_equal(algebraic_simplify((x + zero * y) / one + (y * one) - zero / x), x + y)
 
+def test_double_partition():
+    # Pattern 1
+    conv2d_p = is_op('nn.conv2d')(wildcard(), wildcard())
+    bias_add_p = is_op("nn.bias_add")(conv2d_p, wildcard())
+    relu_p = is_op('nn.relu')(bias_add_p)
+
+    # Graph
+    x = relay.var('input')
+    w = relay.var('weight')
+    b = relay.var('bias')
+    w2 = relay.var('weight')
+    b2 = relay.var('bias')
+    conv2d = relay.op.nn.conv2d(x, w)
+    bias_add = relay.op.nn.bias_add(conv2d, b)
+    relu = relay.op.nn.relu(bias_add)
+    conv2d2 = relay.op.nn.conv2d(relu, w2)
+    bias_add2 = relay.op.nn.bias_add(conv2d2, b2)
+
+    partitioned = bias_add2
+    for pat, label in [(relu_p, "conv_bias_relu"), (bias_add_p, "conv_bias")]:
+        partitioned = pat.partition(partitioned, {"Composite": label})
+
+    inpf = relay.var("input")
+    weightf = relay.var("weight")
+    biasf = relay.var("bias")
+    func0 = relay.Function([inpf, weightf, biasf],
+                           relay.op.nn.relu(relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf))).with_attr("Composite", "conv_bias_relu").with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_nn.relu_")
+    inpf = relay.var("input")
+    weightf = relay.var("weight")
+    biasf = relay.var("bias")
+    func1 = relay.Function([inpf, weightf, biasf], relay.op.nn.bias_add(relay.op.nn.conv2d(inpf, weightf), biasf)).with_attr("Composite", "conv_bias").with_attr("PartitionedFromPattern", "nn.conv2d_nn.bias_add_")
+
+    expected = func1(func0(x, w, b), w2, b2)
+    assert tvm.ir.structural_equal(partitioned, expected)
+
 def test_partition_dominator():
     # Pattern
     is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
@@ -721,10 +756,10 @@ def generate_diamond(inp, weight):
     out = generate_diamond(inp*inp, weight*weight)
     # Check
     partitioned = diamond.partition(out)
-
+
     i = relay.Var("input")
     w = relay.Var("weight")
-    f = relay.Function([i, w], generate_diamond(i, w))
+    f = relay.Function([i, w], generate_diamond(i, w)).with_attr("PartitionedFromPattern", "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_")
     assert tvm.ir.structural_equal(partitioned, f(inp*inp, weight*weight))
 
 def test_quadruple_partition_dominator():
@@ -783,10 +818,16 @@ def nested_diamond(inp, weight):
     )
 
     functions = []
-    for f in [classic_diamond, deeper_diamond, single_branch, nested_diamond]:
+    partition_names = [
+        "nn.conv2d_nn.relu_nn.relu_nn.leaky_relu_add_",
+        "nn.conv2d_nn.relu_nn.relu_tanh_nn.leaky_relu_add_",
+        "nn.conv2d_nn.relu_nn.relu_tanh_add_",
+        "nn.conv2d_nn.relu_add_tanh_nn.leaky_relu_add_"
+    ]
+    for i, f in enumerate([classic_diamond, deeper_diamond, single_branch, nested_diamond]):
         inpf = relay.var("input")
         weightf = relay.var("weight")
-        functions.append(relay.Function([inpf, weightf], f(inpf, weightf)))
+        functions.append(relay.Function([inpf, weightf], f(inpf, weightf)).with_attr("PartitionedFromPattern", partition_names[i]))
 
     reference = functions[3](
         functions[2](
@@ -816,7 +857,7 @@ def test_parition_batchnorm():
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
     # Put the arguments in topological order for the reference
-    f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
+    f = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN)
     assert tvm.ir.structural_equal(partitioned, f(gamma, x, mean, var, beta))
@@ -836,14 +877,14 @@ def test_parition_double_batchnorm():
     meanf = relay.var('meanf')
     betaf = relay.var('betaf')
     gammaf = relay.var('gammaf')
-    f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf))
+    f1 = relay.Function([gammaf, xf, meanf, varf, betaf], get_BN(xf, varf, meanf, betaf, gammaf)).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
     # The partitioner doesn't replace duplicates, so we use two copies of the function
     xf2 = relay.var('xf2')
     varf2 = relay.var('varf2')
     meanf2 = relay.var('meanf2')
     betaf2 = relay.var('betaf2')
     gammaf2 = relay.var('gammaf2')
-    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2))
+    f2 = relay.Function([gammaf2, xf2, meanf2, varf2, betaf2], get_BN(xf2, varf2, meanf2, betaf2, gammaf2)).with_attr("PartitionedFromPattern", "subtract_multiply_add_sqrt_divide_add_")
 
     partitioned = BatchnormCallback().pattern.partition(BN2)
     reference = f2(gamma, f1(gamma, x, mean, var, beta), mean, var, beta)
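
Usage sketch: a minimal example of the new attrs parameter, assuming a TVM build
that includes this patch. The graph and the "my_conv_relu" label below are
illustrative only; the API calls (is_op, wildcard, DFPattern.partition) are the
ones changed above.

    import tvm
    from tvm import relay
    from tvm.relay.dataflow_pattern import is_op, wildcard

    # Pattern: nn.conv2d followed by nn.relu
    pat = is_op("nn.relu")(is_op("nn.conv2d")(wildcard(), wildcard()))

    # Graph to partition
    x = relay.var("x")
    w = relay.var("w")
    out = relay.op.nn.relu(relay.op.nn.conv2d(x, w))

    # Partition, attaching a user-supplied "Composite" attribute. The pass also
    # adds the default PartitionedFromPattern="nn.conv2d_nn.relu_" label (matched
    # node names in topological order), which marks the extracted function so the
    # grouper will not partition it again.
    partitioned = pat.partition(out, {"Composite": "my_conv_relu"})
    print(partitioned)  # a call to a Function carrying both attributes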