From ad17ededc31a6d7bf91059c30d0b265b8621fe4a Mon Sep 17 00:00:00 2001 From: mbaret <55580676+mbaret@users.noreply.github.com> Date: Wed, 8 Apr 2020 04:12:15 +0100 Subject: [PATCH] [RELAY][BYOC] Add support for composite functions in BYOC (#5261) * [RELAY] Add 'check' functions to MergeComposite Currently, MergeComposite can only perform structural matches. This patch introduces the ability to specify a 'check' function alongside the pattern which can include custom logic to determine whether an extracted pattern should be merged. For example, if you only want to merge 'NHWC' convolutions, you can specify a 'check' function which queries the data_layout value of the extracted pattern (see the test). Change-Id: I9337ce39f10997051a286d888be38ed0d410d340 * [RELAY] Reformat merge_composite.cc Run clang-format on merge_composite.cc Change-Id: I1736bff798cc6d93e57519b08ab3362869098779 * [RELAY][BYOC] Support composite functions in AnnotateTarget This patch introduces support to annotate composite functions in the AnnotateTarget pass. In order for a composite function to be annotated, you should name it according to the style: {codegen}.{name} eg. dnnl.add_relu Change-Id: I74d6c0b506153d866f6d1feb203b32dad59f2871 --- python/tvm/relay/transform/transform.py | 17 +++- src/relay/transforms/annotate_target.cc | 40 +++++++-- src/relay/transforms/merge_composite.cc | 90 +++++++++---------- tests/python/relay/test_annotate_target.py | 46 ++++++++++ .../python/relay/test_pass_merge_composite.py | 38 ++++++++ 5 files changed, 173 insertions(+), 58 deletions(-) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index 41aa040952770..ce4ac79a88d05 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -378,9 +378,12 @@ def MergeComposite(pattern_table): Parameters ---------- pattern_table : list(tuple) - A list of (pattern_name, pattern) tuples. + A list of (pattern_name, pattern, check) tuples. The order of the patterns in the list will determine the order of priority in which they are matched. + 'check' is a function to check whether an extracted pattern matches. + It can be implemented by pattern writer but if not specified it will + always return True. Returns ------- @@ -390,11 +393,19 @@ def MergeComposite(pattern_table): """ pattern_names = [] patterns = [] - for pattern_name, pattern in pattern_table: + checks = [] + for tup in pattern_table: + if len(tup) == 2: + pattern_name, pattern = tup + check = lambda extract: True + elif len(tup) == 3: + pattern_name, pattern, check = tup + pattern_names.append(pattern_name) patterns.append(pattern) + checks.append(check) - return _ffi_api.MergeComposite(pattern_names, patterns) + return _ffi_api.MergeComposite(pattern_names, patterns, *checks) def MergeCompilerRegions(): diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b546f05b46e4e..c3d34cb9ab7cd 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator { if (expr->IsInstance()) { Call call = Downcast(expr); auto fannotate = Op::GetAttr("target." + target_); - Op op = Downcast(call->op); - CHECK(op.defined()); - if (fannotate.count(op)) { - return fannotate[op](call->attrs, call->args); + if (call->op->IsInstance()) { + Op op = Downcast(call->op); + CHECK(op.defined()); + if (fannotate.count(op)) { + return fannotate[op](call->attrs, call->args); + } + } else if (call->op->IsInstance()) { + // handle composite functions + Function func = Downcast(call->op); + CHECK(func.defined()); + auto comp_name = func->GetAttr(attr::kComposite); + if (comp_name.defined()) { + size_t i = comp_name->value.find('.'); + if (i != std::string::npos) { + std::string target = comp_name->value.substr(0, i); + if (target == target_) return true; + } + } } } if (expr->IsInstance()) { @@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator { } Expr VisitExpr_(const CallNode* cn) { - // TODO(@zhiics, @comaniac) Handle composite functions. auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); @@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator { } } - Expr VisitExpr_(const FunctionNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); + Expr VisitExpr_(const FunctionNode* fn) { + Function func; + Expr new_body; + // don't step into composite functions + if (fn->GetAttr(attr::kComposite).defined()) { + func = GetRef(fn); + new_body = func->body; + } else { + auto new_e = ExprMutator::VisitExpr_(fn); + func = Downcast(new_e); + new_body = InsertEnd(func->body); + } - auto func = Downcast(new_e); return Function( func->params, - InsertEnd(func->body), + new_body, func->ret_type, func->type_params, func->attrs); diff --git a/src/relay/transforms/merge_composite.cc b/src/relay/transforms/merge_composite.cc index e26ff402c3cd6..35b93dced90d0 100644 --- a/src/relay/transforms/merge_composite.cc +++ b/src/relay/transforms/merge_composite.cc @@ -25,11 +25,11 @@ * Relay operators map to a single external operator. */ -#include #include #include #include #include +#include namespace tvm { namespace relay { @@ -37,11 +37,12 @@ namespace merge_composite { class MergeCompositeWrapper : public ExprMutator { public: - explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern) - : pattern_name_(pattern_name), pattern_(pattern) {} + explicit MergeCompositeWrapper(const std::string& pattern_name, const Expr& pattern, + const PackedFunc& check) + : pattern_name_(pattern_name), pattern_(pattern), check_(check) {} Expr ExtractPattern(const Var& pattern, const Expr& root, - Map>* var_map) { + Map>* var_map) { if (var_map->find(pattern->name_hint()) == var_map->end()) { // if we haven't encountered this var yet, make a new free var and associate // it with the value at 'root' @@ -62,12 +63,12 @@ class MergeCompositeWrapper : public ExprMutator { } Expr ExtractPattern(const Constant& pattern, const Expr& root, - Map>* var_map) { + Map>* var_map) { return root; } Expr ExtractPattern(const TupleGetItem& pattern, const Expr& root, - Map>* var_map, Map* call_map) { + Map>* var_map, Map* call_map) { if (!root->IsInstance()) { return Expr(); } @@ -75,14 +76,12 @@ class MergeCompositeWrapper : public ExprMutator { if (pattern->index != root_node->index) { return Expr(); } - if (pattern->tuple->IsInstance() && - root_node->tuple->IsInstance()) { + if (pattern->tuple->IsInstance() && root_node->tuple->IsInstance()) { Expr new_arg; if (call_map->find(pattern->tuple) != call_map->end()) { new_arg = (*call_map)[pattern->tuple]; } else { - new_arg = ExtractPattern(Downcast(pattern->tuple), - Downcast(root_node->tuple), + new_arg = ExtractPattern(Downcast(pattern->tuple), Downcast(root_node->tuple), var_map, call_map); call_map->Set(pattern->tuple, new_arg); } @@ -104,20 +103,18 @@ class MergeCompositeWrapper : public ExprMutator { * and free variables. The free variables indicate where the pattern can 'attach' in your * graph. This function takes the final call node of the pattern and the call node currently * being traversed in the Relay graph. It traverses through the pattern in lockstep with call node - * from the graph (referred to as the 'root' node here) to check they're identical. If at any point - * they differ, an empty expression is returned to signify the extract failed. If a free var is - * reached in the pattern, the corresponding value in the root is associated with the name of the - * free var (via the var_map) so that when we construct the composite function, the inputs match - * up correctly with the rest of the graph. The return value of this function when successful is - * a new Relay expression ready to be wrapped into a composite function. + * from the graph (referred to as the 'root' node here) to check they're identical. If at any + * point they differ, an empty expression is returned to signify the extract failed. If a free var + * is reached in the pattern, the corresponding value in the root is associated with the name of + * the free var (via the var_map) so that when we construct the composite function, the inputs + * match up correctly with the rest of the graph. The return value of this function when + * successful is a new Relay expression ready to be wrapped into a composite function. */ - Expr ExtractPattern(const Call& pattern, const Call& root, - Map>* var_map, Map* call_map) { + Expr ExtractPattern(const Call& pattern, const Call& root, Map>* var_map, + Map* call_map) { // check to make sure both calls are to operators (not functions) - if (!pattern->op->IsInstance() || !root->op->IsInstance()) - return Expr(); - if (pattern->op.as()->name != root->op.as()->name) - return Expr(); + if (!pattern->op->IsInstance() || !root->op->IsInstance()) return Expr(); + if (pattern->op.as()->name != root->op.as()->name) return Expr(); unsigned int i = 0; Array new_args; @@ -133,27 +130,20 @@ class MergeCompositeWrapper : public ExprMutator { return Expr(); } // if it's a call node, recursively call this function - new_arg = ExtractPattern(Downcast(arg), - Downcast(root->args[i]), - var_map, call_map); + new_arg = + ExtractPattern(Downcast(arg), Downcast(root->args[i]), var_map, call_map); call_map->Set(arg, new_arg); } } else if (arg->IsInstance()) { // if there's a var in the pattern, it must be a free var // so call the function to update the var_map - new_arg = ExtractPattern(Downcast(arg), - root->args[i], - var_map); + new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); } else if (arg->IsInstance()) { // if there's a constant, simply get the corresponding // value of the constant from the root - new_arg = ExtractPattern(Downcast(arg), - root->args[i], - var_map); + new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map); } else if (arg->IsInstance()) { - new_arg = ExtractPattern(Downcast(arg), - root->args[i], - var_map, call_map); + new_arg = ExtractPattern(Downcast(arg), root->args[i], var_map, call_map); } if (!new_arg.defined()) { return Expr(); @@ -169,8 +159,7 @@ class MergeCompositeWrapper : public ExprMutator { if (call->op->IsInstance()) { Function func = Downcast(call->op); CHECK(func.defined()); - const auto name_node = - func->GetAttr(attr::kComposite); + const auto name_node = func->GetAttr(attr::kComposite); // don't step into existing composite functions if (name_node.defined() && name_node->value != "") { tvm::Array new_args; @@ -184,8 +173,7 @@ class MergeCompositeWrapper : public ExprMutator { Expr expr = ExprMutator::VisitExpr_(cn); call = Downcast(expr); - if (!call->op->IsInstance()) - return std::move(call); + if (!call->op->IsInstance()) return std::move(call); // only call patterns are supported Call pattern = Downcast(pattern_); @@ -193,7 +181,7 @@ class MergeCompositeWrapper : public ExprMutator { Map> args_map; Map call_map; auto extract = ExtractPattern(pattern, call, &args_map, &call_map); - if (extract.defined()) { + if (extract.defined() && static_cast(check_(extract))) { auto free_vars = FreeVars(extract); // make the composite function auto f = Function(free_vars, extract, call->checked_type_, {}, DictAttrs()); @@ -215,17 +203,20 @@ class MergeCompositeWrapper : public ExprMutator { std::string pattern_name_; /*! \brief The pattern to match */ Expr pattern_; + /*! \brief The function to check whether an extract is supported */ + PackedFunc check_; }; -Expr MergeComposite(const Expr& expr, - const Array& pattern_names, const Array& patterns) { +Expr MergeComposite(const Expr& expr, const Array& pattern_names, + const Array& patterns, const std::vector& checks) { CHECK_EQ(pattern_names.size(), patterns.size()); Expr merged_expr = expr; // merge the patterns one-by-one in order for (size_t i = 0; i < patterns.size(); i++) { std::string pattern_name = pattern_names[i]->value; Expr pattern = patterns[i]; - merged_expr = MergeCompositeWrapper(pattern_name, pattern).Mutate(merged_expr); + PackedFunc check = checks[i]; + merged_expr = MergeCompositeWrapper(pattern_name, pattern, check).Mutate(merged_expr); } return merged_expr; } @@ -235,18 +226,25 @@ Expr MergeComposite(const Expr& expr, namespace transform { Pass MergeComposite(const tvm::Array& pattern_names, - const tvm::Array& patterns) { + const tvm::Array& patterns, const std::vector& checks) { runtime::TypedPackedFunc pass_func = [=](Function f, IRModule m, PassContext pc) { return Downcast( - relay::merge_composite::MergeComposite(f, pattern_names, patterns)); + relay::merge_composite::MergeComposite(f, pattern_names, patterns, checks)); }; auto func_pass = CreateFunctionPass(pass_func, 0, "MergeComposite", {}); return func_pass; } -TVM_REGISTER_GLOBAL("relay._transform.MergeComposite") -.set_body_typed(MergeComposite); +TVM_REGISTER_GLOBAL("relay._transform.MergeComposite").set_body([](TVMArgs args, TVMRetValue* rv) { + tvm::Array pattern_names = args[0]; + tvm::Array patterns = args[1]; + std::vector checks; + for (int i = 2; i < args.size(); i++) { + checks.push_back(args[i]); + } + *rv = MergeComposite(pattern_names, patterns, checks); +}); } // namespace transform diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 87cf7616e232e..0a2abd73d5eb2 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -219,7 +219,53 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_composite_function(): + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + r = relay.Call(add_relu, [a, b]) + f = relay.Function([a, b], r) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + cb_1 = relay.annotation.compiler_begin(a, "test") + cb_2 = relay.annotation.compiler_begin(b, "test") + r = relay.Call(add_relu, [cb_1, cb_2]) + ce_1 = relay.annotation.compiler_end(r, "test") + f = relay.Function([a, b], ce_1) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_multiple_ends() test_extern_dnnl() test_extern_dnnl_mobilenet() + test_composite_function() diff --git a/tests/python/relay/test_pass_merge_composite.py b/tests/python/relay/test_pass_merge_composite.py index 3c70cf237c941..110d855216e4d 100644 --- a/tests/python/relay/test_pass_merge_composite.py +++ b/tests/python/relay/test_pass_merge_composite.py @@ -732,6 +732,43 @@ def expected(): assert tvm.ir.structural_equal(result, expected, map_free_vars=True) +def test_pattern_with_check(): + def before(): + x = relay.var('x', shape=(1, 10, 10, 10)) + w = relay.var('w', shape=(10, 10, 3, 3)) + b = relay.var('b', shape=(8,)) + conv = relay.nn.conv2d(x, + w, + kernel_size=(3, 3), + kernel_layout="OIHW", + data_layout="NHWC") + bias = relay.nn.bias_add(conv, b) + relu = relay.nn.relu(bias) + return relay.Function([x, w, b], relu) + + def _check_true(extract): + conv = extract.args[0].args[0] + return conv.attrs.data_layout == "NHWC" + + def _check_false(extract): + conv = extract.args[0].args[0] + return conv.attrs.data_layout == "NCHW" + + pattern_table_true = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_true) + ] + pattern_table_false = [ + ("conv_bias_relu", make_conv_bias_relu_pattern(), _check_false) + ] + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_false)) + expected = run_opt_pass(before(), relay.transform.InferType()) + assert tvm.ir.structural_equal(result, expected, map_free_vars=True) + + result = run_opt_pass(before(), relay.transform.MergeComposite(pattern_table_true)) + assert result.body.op.attrs["Composite"] == "conv_bias_relu" + + if __name__ == "__main__": test_simple_merge() test_branch_merge() @@ -741,3 +778,4 @@ def expected(): test_multiple_input_subgraphs() test_reuse_call_merge() test_tuple_get_item_merge() + test_pattern_with_check()