From dd6f722aa632a94ab54e0ea98665419136d68c20 Mon Sep 17 00:00:00 2001 From: Matthew Barrett Date: Tue, 7 Apr 2020 10:25:36 +0100 Subject: [PATCH] [RELAY][BYOC] Support composite functions in AnnotateTarget This patch introduces support to annotate composite functions in the AnnotateTarget pass. In order for a composite function to be annotated, you should name it according to the style: {codegen}.{name} eg. dnnl.add_relu Change-Id: I73d6c0b506153d866f6d1feb203b32dad59f2871 --- src/relay/transforms/annotate_target.cc | 40 ++++++++++++++----- tests/python/relay/test_annotate_target.py | 46 ++++++++++++++++++++++ 2 files changed, 77 insertions(+), 9 deletions(-) diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index b546f05b46e4e..c3d34cb9ab7cd 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator { if (expr->IsInstance()) { Call call = Downcast(expr); auto fannotate = Op::GetAttr("target." + target_); - Op op = Downcast(call->op); - CHECK(op.defined()); - if (fannotate.count(op)) { - return fannotate[op](call->attrs, call->args); + if (call->op->IsInstance()) { + Op op = Downcast(call->op); + CHECK(op.defined()); + if (fannotate.count(op)) { + return fannotate[op](call->attrs, call->args); + } + } else if (call->op->IsInstance()) { + // handle composite functions + Function func = Downcast(call->op); + CHECK(func.defined()); + auto comp_name = func->GetAttr(attr::kComposite); + if (comp_name.defined()) { + size_t i = comp_name->value.find('.'); + if (i != std::string::npos) { + std::string target = comp_name->value.substr(0, i); + if (target == target_) return true; + } + } } } if (expr->IsInstance()) { @@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator { } Expr VisitExpr_(const CallNode* cn) { - // TODO(@zhiics, @comaniac) Handle composite functions. auto new_e = ExprMutator::VisitExpr_(cn); Call call = Downcast(new_e); @@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator { } } - Expr VisitExpr_(const FunctionNode* op) { - auto new_e = ExprMutator::VisitExpr_(op); + Expr VisitExpr_(const FunctionNode* fn) { + Function func; + Expr new_body; + // don't step into composite functions + if (fn->GetAttr(attr::kComposite).defined()) { + func = GetRef(fn); + new_body = func->body; + } else { + auto new_e = ExprMutator::VisitExpr_(fn); + func = Downcast(new_e); + new_body = InsertEnd(func->body); + } - auto func = Downcast(new_e); return Function( func->params, - InsertEnd(func->body), + new_body, func->ret_type, func->type_params, func->attrs); diff --git a/tests/python/relay/test_annotate_target.py b/tests/python/relay/test_annotate_target.py index 87cf7616e232e..0a2abd73d5eb2 100644 --- a/tests/python/relay/test_annotate_target.py +++ b/tests/python/relay/test_annotate_target.py @@ -219,7 +219,53 @@ def after(): assert tvm.ir.structural_equal(expected, result) +def test_composite_function(): + def before(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + r = relay.Call(add_relu, [a, b]) + f = relay.Function([a, b], r) + mod = tvm.IRModule.from_expr(f) + return mod + + def after(): + a = relay.var('a', shape=(10, 10)) + b = relay.var('b', shape=(10, 10)) + + # add_relu function + in_1 = relay.var('in_1', shape=(10, 10)) + in_2 = relay.var('in_2', shape=(10, 10)) + add_node = relay.add(in_1, in_2) + relu_node = relay.nn.relu(add_node) + add_relu = relay.Function([in_1, in_2], relu_node) + add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu")) + + # merged function + cb_1 = relay.annotation.compiler_begin(a, "test") + cb_2 = relay.annotation.compiler_begin(b, "test") + r = relay.Call(add_relu, [cb_1, cb_2]) + ce_1 = relay.annotation.compiler_end(r, "test") + f = relay.Function([a, b], ce_1) + mod = tvm.IRModule.from_expr(f) + return mod + + result = transform.AnnotateTarget("test")(before()) + expected = transform.InferType()(after()) + assert tvm.ir.structural_equal(expected, result) + + if __name__ == "__main__": test_multiple_ends() test_extern_dnnl() test_extern_dnnl_mobilenet() + test_composite_function()