Skip to content

Commit

Permalink
[RELAY][BYOC] Support composite functions in AnnotateTarget
Browse files Browse the repository at this point in the history
This patch introduces support to annotate composite functions
in the AnnotateTarget pass. In order for a composite function
to be annotated, you should name it according to the style:

{codegen}.{name}
eg. dnnl.add_relu

Change-Id: I73d6c0b506153d866f6d1feb203b32dad59f2871
  • Loading branch information
mbaret committed Apr 7, 2020
1 parent 6c7d5d4 commit dd6f722
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 9 deletions.
40 changes: 31 additions & 9 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,24 @@ class AnnotateTargetWrapper : public ExprMutator {
if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
if (call->op->IsInstance<OpNode>()) {
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
} else if (call->op->IsInstance<FunctionNode>()) {
// handle composite functions
Function func = Downcast<Function>(call->op);
CHECK(func.defined());
auto comp_name = func->GetAttr<tir::StringImm>(attr::kComposite);
if (comp_name.defined()) {
size_t i = comp_name->value.find('.');
if (i != std::string::npos) {
std::string target = comp_name->value.substr(0, i);
if (target == target_) return true;
}
}
}
}
if (expr->IsInstance<TupleGetItemNode>()) {
Expand All @@ -77,7 +91,6 @@ class AnnotateTargetWrapper : public ExprMutator {
}

Expr VisitExpr_(const CallNode* cn) {
// TODO(@zhiics, @comaniac) Handle composite functions.
auto new_e = ExprMutator::VisitExpr_(cn);

Call call = Downcast<Call>(new_e);
Expand Down Expand Up @@ -130,13 +143,22 @@ class AnnotateTargetWrapper : public ExprMutator {
}
}

Expr VisitExpr_(const FunctionNode* op) {
auto new_e = ExprMutator::VisitExpr_(op);
Expr VisitExpr_(const FunctionNode* fn) {
Function func;
Expr new_body;
// don't step into composite functions
if (fn->GetAttr<tir::StringImm>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e);
new_body = InsertEnd(func->body);
}

auto func = Downcast<Function>(new_e);
return Function(
func->params,
InsertEnd(func->body),
new_body,
func->ret_type,
func->type_params,
func->attrs);
Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_annotate_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,53 @@ def after():
assert tvm.ir.structural_equal(expected, result)


def test_composite_function():
def before():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))

# add_relu function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))

# merged function
r = relay.Call(add_relu, [a, b])
f = relay.Function([a, b], r)
mod = tvm.IRModule.from_expr(f)
return mod

def after():
a = relay.var('a', shape=(10, 10))
b = relay.var('b', shape=(10, 10))

# add_relu function
in_1 = relay.var('in_1', shape=(10, 10))
in_2 = relay.var('in_2', shape=(10, 10))
add_node = relay.add(in_1, in_2)
relu_node = relay.nn.relu(add_node)
add_relu = relay.Function([in_1, in_2], relu_node)
add_relu = add_relu.with_attr("Composite", tvm.tir.StringImm("test.add_relu"))

# merged function
cb_1 = relay.annotation.compiler_begin(a, "test")
cb_2 = relay.annotation.compiler_begin(b, "test")
r = relay.Call(add_relu, [cb_1, cb_2])
ce_1 = relay.annotation.compiler_end(r, "test")
f = relay.Function([a, b], ce_1)
mod = tvm.IRModule.from_expr(f)
return mod

result = transform.AnnotateTarget("test")(before())
expected = transform.InferType()(after())
assert tvm.ir.structural_equal(expected, result)


if __name__ == "__main__":
test_multiple_ends()
test_extern_dnnl()
test_extern_dnnl_mobilenet()
test_composite_function()

0 comments on commit dd6f722

Please sign in to comment.