Skip to content

Commit

Permalink
[RELAY] Add MergeCompilerRegions pass (apache#5134)
Browse files Browse the repository at this point in the history
* [RELAY] Add MergeCompilerRegions pass

This pass is part of the flow to support creating compiler
regions with multiple outputs. It should be called after
AnnotateTarget and will merge together regions that share
the same target to create larger compiler regions that can
be off-loaded to external codegens.

This pass implements an algorithm to ensure that during the
merging, no data dependency issues are created. See the tests
for an example of this case.

Co-authored-by: Ramana Radhakrishnan  <[email protected]>
Co-authored-by: Manupa Karunaratne    <[email protected]>

Change-Id: Ibd99083564608d888482f57c5080109f3eefec88

* [RELAY] Annotate compiler_ends on each edge

This alters the behaviour of the AnnotateTarget
pass to enforce the property that all compiler
annotations exist along a single data flow edge.
Specifically, this means they should have exactly
one parent and one child.

Change-Id: I0e74803a77767f4f377d17755a13a74a30909797

* Fix comment

* Rebase *Node::make

* Moved block outside for loop

* Code style

* Update make API

* Remove comment

* Remove redundant 'else's

* Make one line

* Fix comment

* RefWrite

* Fix merge ordering

* Add the RFC example as a test

* [FIX] Fixed merging behaviour in AnnotateRegionSet

Deleting items from a list while iterating over it is
undefined behaviour and sometimes segfaults. This change
defers all item deletion until the iteration has finished.

* Added checks

* Move comment

* Update comments
  • Loading branch information
mbaret authored and Trevor Morris committed Apr 16, 2020
1 parent 31bbc63 commit b276800
Show file tree
Hide file tree
Showing 6 changed files with 726 additions and 27 deletions.
11 changes: 11 additions & 0 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,17 @@ def MergeComposite(pattern_table):
return _ffi_api.MergeComposite(pattern_names, patterns)


def MergeCompilerRegions():
    """Merge together compiler regions.

    Regions that share the same target are merged to create larger
    compiler regions that can be off-loaded to external codegens.
    This pass should be called after AnnotateTarget.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that merges compiler regions.
    """
    return _ffi_api.MergeCompilerRegions()


def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
Expand Down
6 changes: 5 additions & 1 deletion src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,18 @@ void AnnotatedRegionSetNode::MergeRegions(AnnotatedRegion src,
}
// if any of the outputs of src are inputs of dest, they become internal nodes
// so remove them from outs
std::vector<Expr> ins_to_remove;
for (const auto& input : dest->ins) {
auto call = Downcast<Call>(input);
auto it = std::find(src->outs.begin(), src->outs.end(), call->args[0]);
if (it != src->outs.end()) {
dest->outs.remove(*it);
dest->ins.remove(input);
ins_to_remove.push_back(input);
}
}
for (const auto& input : ins_to_remove) {
dest->ins.remove(input);
}
regions_.erase(src);
}

Expand Down
150 changes: 124 additions & 26 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,46 +38,144 @@ class AnnotateTargetWrapper : public ExprMutator {
public:
/*!
 * \brief Construct an annotator for a single target.
 * \param target Name stored in target_; it is used to look up the
 *        "target.<name>" op attribute and is passed to the
 *        compiler_begin / compiler_end annotation makers.
 */
explicit AnnotateTargetWrapper(const std::string& target) : target_(target) {}

/*!
 * \brief Entry point of the annotator.
 *
 * Rewrites the whole expression and then gives the outermost result the
 * chance to be wrapped in compiler_end, since no parent node exists to do so.
 *
 * \param expr The expression to annotate.
 * \return The annotated expression.
 */
Expr Annotate(const Expr& expr) {
  const Expr rewritten = Mutate(expr);
  return InsertEnd(rewritten);
}

bool IsSupported(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
Call call = Downcast<Call>(expr);
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + target_);
Op op = Downcast<Op>(call->op);
CHECK(op.defined());
if (fannotate.count(op)) {
return fannotate[op](call->attrs, call->args);
}
}
return false;
}

/*!
 * \brief Wrap an expression in compiler_end when the target supports it.
 * \param arg The expression to (possibly) wrap.
 * \return compiler_end(arg, target_) if IsSupported(arg) holds, otherwise
 *         arg unchanged.
 */
Expr InsertEnd(const Expr& arg) {
  // Unsupported expressions pass through untouched.
  if (!IsSupported(arg)) {
    return arg;
  }
  const auto* make_end =
      runtime::Registry::Get("relay.op.annotation._make.compiler_end");
  CHECK(make_end);
  Expr wrapped = (*make_end)(arg, target_);
  return wrapped;
}

/*!
 * \brief Annotate a call node.
 *
 * The operands are first rewritten by the base mutator. Every rewritten
 * operand that is supported on the target is then wrapped in compiler_end.
 * If the resulting call is itself supported, each of its operands is
 * additionally wrapped in compiler_begin. The compiler_end for this call is
 * not added here; it is inserted by whichever node consumes the call, or by
 * Annotate for the outermost expression.
 *
 * \param cn The call node being visited.
 * \return The rewritten call.
 */
Expr VisitExpr_(const CallNode* cn) {
  // TODO(@zhiics, @comaniac) Handle composite functions.
  auto new_e = ExprMutator::VisitExpr_(cn);

  Call call = Downcast<Call>(new_e);

  // add end annotations if the args are supported
  Array<Expr> compiler_ends;
  for (const auto& it : call->args) {
    compiler_ends.push_back(InsertEnd(it));
  }
  call = Call(call->op, compiler_ends, call->attrs);

  // add begin annotations if the call node is supported
  if (IsSupported(call)) {
    const auto* begin_op =
        runtime::Registry::Get("relay.op.annotation._make.compiler_begin");
    // The registry lookup is loop-invariant, so validate it once up front.
    CHECK(begin_op);
    tvm::Array<tvm::relay::Expr> compiler_begins;
    for (const auto& it : call->args) {
      Expr begin = (*begin_op)(it, target_);
      compiler_begins.push_back(begin);
    }
    call = Call(call->op, compiler_begins, call->attrs);
  }

  // Call -> Expr is a converting return, so the move must be explicit.
  return std::move(call);
}

/*!
 * \brief Annotate a tuple: each supported field is closed with compiler_end.
 * \param op The tuple node being visited.
 * \return A new tuple built from the (possibly wrapped) rewritten fields.
 */
Expr VisitExpr_(const TupleNode* op) {
  const Tuple visited = Downcast<Tuple>(ExprMutator::VisitExpr_(op));

  Array<Expr> closed_fields;
  for (const auto& item : visited->fields) {
    closed_fields.push_back(InsertEnd(item));
  }
  return Tuple(closed_fields);
}

/*!
 * \brief Annotate a tuple projection: the projected tuple is closed with
 *        compiler_end when it is supported.
 * \param op The tuple-get-item node being visited.
 * \return A new projection with the same index over the (possibly wrapped)
 *         rewritten tuple.
 */
Expr VisitExpr_(const TupleGetItemNode* op) {
  const TupleGetItem visited = Downcast<TupleGetItem>(ExprMutator::VisitExpr_(op));
  const Expr closed_tuple = InsertEnd(visited->tuple);
  return TupleGetItem(closed_tuple, visited->index);
}

/*!
 * \brief Annotate a function: its body is closed with compiler_end when it is
 *        supported; parameters, types and attributes are carried over.
 * \param op The function node being visited.
 * \return A new function around the (possibly wrapped) rewritten body.
 */
Expr VisitExpr_(const FunctionNode* op) {
  const Function visited = Downcast<Function>(ExprMutator::VisitExpr_(op));
  const Expr closed_body = InsertEnd(visited->body);
  return Function(visited->params, closed_body, visited->ret_type,
                  visited->type_params, visited->attrs);
}

/*!
 * \brief Annotate a let binding: both the bound value and the body are closed
 *        with compiler_end when they are supported.
 * \param op The let node being visited.
 * \return A new let with the same variable and (possibly wrapped) rewritten
 *         value and body.
 */
Expr VisitExpr_(const LetNode* op) {
  const Let visited = Downcast<Let>(ExprMutator::VisitExpr_(op));
  const Expr closed_value = InsertEnd(visited->value);
  const Expr closed_body = InsertEnd(visited->body);
  return Let(visited->var, closed_value, closed_body);
}

/*!
 * \brief Annotate a conditional: the condition and both branches are closed
 *        with compiler_end when they are supported.
 * \param op The if node being visited.
 * \return A new conditional over the (possibly wrapped) rewritten children.
 */
Expr VisitExpr_(const IfNode* op) {
  const If visited = Downcast<If>(ExprMutator::VisitExpr_(op));
  const Expr closed_cond = InsertEnd(visited->cond);
  const Expr closed_then = InsertEnd(visited->true_branch);
  const Expr closed_else = InsertEnd(visited->false_branch);
  return If(closed_cond, closed_then, closed_else);
}

/*!
 * \brief Annotate a reference creation: the initial value is closed with
 *        compiler_end when it is supported.
 * \param op The ref-create node being visited.
 * \return A new RefCreate over the (possibly wrapped) rewritten value.
 */
Expr VisitExpr_(const RefCreateNode* op) {
  const RefCreate visited = Downcast<RefCreate>(ExprMutator::VisitExpr_(op));
  const Expr closed_value = InsertEnd(visited->value);
  return RefCreate(closed_value);
}

/*!
 * \brief Annotate a reference read: the reference operand is closed with
 *        compiler_end when it is supported.
 * \param op The ref-read node being visited.
 * \return A new RefRead over the (possibly wrapped) rewritten reference.
 */
Expr VisitExpr_(const RefReadNode* op) {
  const RefRead visited = Downcast<RefRead>(ExprMutator::VisitExpr_(op));
  const Expr closed_ref = InsertEnd(visited->ref);
  return RefRead(closed_ref);
}

/*!
 * \brief Annotate a reference write: both the reference and the written value
 *        are closed with compiler_end when they are supported.
 * \param op The ref-write node being visited.
 * \return A new RefWrite over the (possibly wrapped) rewritten operands.
 */
Expr VisitExpr_(const RefWriteNode* op) {
  const RefWrite visited = Downcast<RefWrite>(ExprMutator::VisitExpr_(op));
  const Expr closed_ref = InsertEnd(visited->ref);
  const Expr closed_value = InsertEnd(visited->value);
  return RefWrite(closed_ref, closed_value);
}

private:
std::string target_;
};

/*!
 * \brief Annotate an expression with compiler_begin / compiler_end markers
 *        for the given target.
 *
 * Goes through AnnotateTargetWrapper::Annotate rather than calling Mutate
 * directly, so that the outermost expression is also closed with compiler_end
 * when it is supported.
 *
 * \param expr The expression to annotate.
 * \param target The name of the target.
 * \return The annotated expression.
 */
Expr AnnotateTarget(const Expr& expr, const std::string& target) {
  return AnnotateTargetWrapper(target).Annotate(expr);
}

} // namespace annotate_target
Expand Down
Loading

0 comments on commit b276800

Please sign in to comment.