Skip to content

Commit

Permalink
Non-Recursive AnnotatedTarget and MergeAnnotation
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Apr 22, 2020
1 parent 56941fb commit af5ff80
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 51 deletions.
85 changes: 39 additions & 46 deletions src/relay/transforms/annotate_target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ const PackedFunc* make_end_op = runtime::Registry::Get("relay.op.annotation._mak

// A helper class to insert annotation boundaries for a program region that will
// be handled by a specific compiler.
class AnnotateTargetWrapper : public ExprMutator {
class AnnotateTargetRewriter : public ExprRewriter {
public:
explicit AnnotateTargetWrapper(Array<runtime::String> targets) : targets_(std::move(targets)) {}
explicit AnnotateTargetRewriter(Array<runtime::String> targets) : targets_(std::move(targets)) {}

/*!
* \brief This function annotates a compiler end and a compiler begin to all arguments.
Expand Down Expand Up @@ -108,29 +108,29 @@ class AnnotateTargetWrapper : public ExprMutator {
return new_op;
}

Expr VisitExpr_(const CallNode* cn) final {
Expr Rewrite_(const CallNode* pre, const Expr& post) final {
// Supported targets for this node. The order implies the priority.
std::vector<std::string> supported_targets;

auto op_node = cn->op.as<OpNode>();
auto op_node = pre->op.as<OpNode>();

// This graph has annotations, meaning that this is not the first time running this pass.
if (op_node && cn->op == compiler_begin_op) {
if (op_node && pre->op == compiler_begin_op) {
// Bypass compiler begin due to lack of target information. It will be processed
// when the following op handling arguments.
CHECK_EQ(cn->args.size(), 1U);
return VisitExpr(cn->args[0]);
} else if (op_node && cn->op == compiler_end_op) {
CHECK_EQ(pre->args.size(), 1U);
return post.as<CallNode>()->args[0];
} else if (op_node && pre->op == compiler_end_op) {
// Override compiler end with the new target.
CHECK_EQ(cn->args.size(), 1U);
auto input_expr = VisitExpr(cn->args[0]);
CHECK_EQ(pre->args.size(), 1U);
auto input_expr = post.as<CallNode>()->args[0];
CHECK(op_expr_to_target_.find(input_expr) != op_expr_to_target_.end());
return InsertAnnotation(input_expr, op_expr_to_target_[input_expr], make_end_op);
}

// Peek the first argument. If it is compiler begin then this node had annotated by
// another target before, so we also consider that target as a supported target.
const CallNode* first_arg_call = cn->args[0].as<CallNode>();
const CallNode* first_arg_call = pre->args[0].as<CallNode>();
if (first_arg_call && first_arg_call->op == compiler_begin_op) {
std::string arg_target = first_arg_call->attrs.as<CompilerAttrs>()->compiler;
if (arg_target != "default") {
Expand All @@ -142,21 +142,21 @@ class AnnotateTargetWrapper : public ExprMutator {
if (op_node) {
// TVM operators: Check target specific op checking function and add to supported_targets
// if it is supported.
Op op = Downcast<Op>(cn->op);
Op op = Downcast<Op>(pre->op);
CHECK(op.defined());
for (const auto& target : this->targets_) {
if (!Op::HasAttr("target." + std::string(target))) {
continue;
}
auto fannotate = Op::GetAttr<FTVMAnnotateTarget>("target." + std::string(target));
if (fannotate.count(op) && fannotate[op](cn->attrs, cn->args)) {
if (fannotate.count(op) && fannotate[op](pre->attrs, pre->args)) {
supported_targets.push_back(target);
}
}
} else if (cn->op->IsInstance<FunctionNode>()) {
} else if (pre->op->IsInstance<FunctionNode>()) {
// Composite function: Add the target of a composite function to supported_targets
// if it is in the target list.
Function func = Downcast<Function>(cn->op);
Function func = Downcast<Function>(pre->op);
CHECK(func.defined());

if (auto comp_name = func->GetAttr<String>(attr::kComposite)) {
Expand All @@ -181,50 +181,47 @@ class AnnotateTargetWrapper : public ExprMutator {
std::string target = supported_targets[0];

// Visit and mutate arguments after the target of this op has been determined.
auto new_call = Downcast<Call>(ExprMutator::VisitExpr_(cn));
Call post_call = Downcast<Call>(post);

// Add annotations to each arg.
auto target_n_args = AnnotateArgs(new_call->args, target);
auto target_n_args = AnnotateArgs(post_call->args, target);
Array<Expr> compiler_begins = std::get<1>(target_n_args);
Call call = Call(new_call->op, compiler_begins, new_call->attrs);
call->checked_type_ = cn->checked_type_;
Call new_call = Call(post_call->op, compiler_begins, post_call->attrs);
new_call->checked_type_ = pre->checked_type_;

// Update the target map.
op_expr_to_target_[call] = target;
op_expr_to_target_[new_call] = target;

return std::move(call);
return new_call;
}

Expr VisitExpr_(const TupleNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<Tuple>(new_e);
Expr Rewrite_(const TupleNode* op, const Expr& post) final {
auto expr = Downcast<Tuple>(post);

auto target_n_args = AnnotateArgs(expr->fields);
auto new_expr = Tuple(std::get<1>(target_n_args));
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}

Expr VisitExpr_(const TupleGetItemNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<TupleGetItem>(new_e);
Expr Rewrite_(const TupleGetItemNode* op, const Expr& post) final {
auto expr = Downcast<TupleGetItem>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->tuple}));
auto new_expr = TupleGetItem(std::get<1>(target_n_args)[0], expr->index);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}

Expr VisitExpr_(const FunctionNode* fn) final {
Expr Rewrite_(const FunctionNode* fn, const Expr& post) final {
Function func;
Expr new_body;
// don't step into composite functions
if (fn->GetAttr<String>(attr::kComposite).defined()) {
func = GetRef<Function>(fn);
new_body = func->body;
} else {
auto new_e = ExprMutator::VisitExpr_(fn);
func = Downcast<Function>(new_e);
func = Downcast<Function>(post);
new_body = func->body;
if (op_expr_to_target_.find(func->body) != op_expr_to_target_.end()) {
new_body = InsertAnnotation(func->body, op_expr_to_target_[func->body], make_end_op);
Expand All @@ -234,19 +231,17 @@ class AnnotateTargetWrapper : public ExprMutator {
return Function(func->params, new_body, func->ret_type, func->type_params, func->attrs);
}

Expr VisitExpr_(const LetNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto let = Downcast<Let>(new_e);
Expr Rewrite_(const LetNode* op, const Expr& post) final {
auto let = Downcast<Let>(post);

auto target_n_args = AnnotateArgs({let->value, let->body});
auto new_expr = Let(let->var, std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}

Expr VisitExpr_(const IfNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<If>(new_e);
Expr Rewrite_(const IfNode* op, const Expr& post) final {
auto expr = Downcast<If>(post);

auto target_n_args = AnnotateArgs({expr->cond, expr->true_branch, expr->false_branch});
CHECK_EQ(std::get<1>(target_n_args).size(), 3U);
Expand All @@ -256,29 +251,26 @@ class AnnotateTargetWrapper : public ExprMutator {
return std::move(new_expr);
}

Expr VisitExpr_(const RefCreateNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<RefCreate>(new_e);
Expr Rewrite_(const RefCreateNode* op, const Expr& post) final {
auto expr = Downcast<RefCreate>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->value}));
auto new_expr = RefCreate(std::get<1>(target_n_args)[0]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}

Expr VisitExpr_(const RefReadNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<RefRead>(new_e);
Expr Rewrite_(const RefReadNode* op, const Expr& post) final {
auto expr = Downcast<RefRead>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref}));
auto new_expr = RefRead(std::get<1>(target_n_args)[0]);
op_expr_to_target_[new_expr] = std::get<0>(target_n_args);
return std::move(new_expr);
}

Expr VisitExpr_(const RefWriteNode* op) final {
auto new_e = ExprMutator::VisitExpr_(op);
auto expr = Downcast<RefWrite>(new_e);
Expr Rewrite_(const RefWriteNode* op, const Expr& post) final {
auto expr = Downcast<RefWrite>(post);

auto target_n_args = AnnotateArgs(Array<Expr>({expr->ref, expr->value}));
auto new_expr = RefWrite(std::get<1>(target_n_args)[0], std::get<1>(target_n_args)[1]);
Expand All @@ -294,7 +286,8 @@ class AnnotateTargetWrapper : public ExprMutator {
};

Expr AnnotateTarget(const Expr& expr, const Array<runtime::String>& targets) {
return AnnotateTargetWrapper(targets).Mutate(expr);
auto rewriter = AnnotateTargetRewriter(targets);
return PostOrderRewrite(expr, &rewriter);
}

} // namespace annotate_target
Expand Down
11 changes: 6 additions & 5 deletions src/relay/transforms/merge_compiler_regions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ class RegionMerger : public ExprVisitor {
std::unordered_map<int, std::unordered_set<int>> region_restrictions_;
};

class MergeAnnotations : public ExprMutator {
class MergeAnnotations : public ExprRewriter {
public:
explicit MergeAnnotations(AnnotatedRegionSet regions) : regions_(regions) {}

Expr VisitExpr_(const CallNode* call) final {
Expr Rewrite_(const CallNode* call, const Expr& post) final {
// Merge annotations which are now internal to a region.
// This happens if we see a compiler begin next to a
// compiler end and they're both in the same region.
Expand All @@ -154,11 +154,12 @@ class MergeAnnotations : public ExprMutator {
auto region1 = regions_->GetRegion(GetRef<Call>(call));
auto region2 = regions_->GetRegion(arg);
if (region1 == region2) {
return VisitExpr(arg->args[0]);
auto post_arg = post.as<CallNode>()->args[0];
return post_arg.as<CallNode>()->args[0];
}
}
}
return ExprMutator::VisitExpr_(call);
return post;
}

private:
Expand All @@ -175,7 +176,7 @@ Expr MergeCompilerRegions(const Expr& expr) {

// Remove annotations that are not in the region boundaries.
MergeAnnotations merge_anno(regions);
return merge_anno.Mutate(expr);
return PostOrderRewrite(expr, &merge_anno);
}

} // namespace merge_compiler_region
Expand Down

0 comments on commit af5ff80

Please sign in to comment.