Skip to content

Commit

Permalink
[BYOC] Use Non-Recursive Visitor/Mutator (apache#5410)
Browse files Browse the repository at this point in the history
* Non-Recursive AnnotatedTarget and MergeAnnotation

* Non-Recursive AnnotatedRegionSet and RegionMerger
  • Loading branch information
comaniac authored and trevor-m committed Jun 18, 2020
1 parent f103631 commit 7cc370f
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 155 deletions.
133 changes: 65 additions & 68 deletions src/relay/analysis/annotated_region_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,32 +86,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
return *ret.first;
}

class AnnotatedRegionSet::Creator : public ExprVisitor {
class AnnotatedRegionSet::Creator : protected MixedModeVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op)
: begin_op_(region_begin_op), end_op_(region_end_op) {}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void AddToArgRegion(Expr expr, Array<Expr> args) {
// Merge argument regions and add itself to the region.

// Find the first open region.
AnnotatedRegion region;
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
continue;
}

region = region_set_->GetRegion(arg);
if (region.defined()) {
break;
}
}

// Try to merge open regions.
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
continue;
}

auto arg_region = region_set_->GetRegion(arg);
CHECK_EQ(region.defined(), arg_region.defined())
<< "Arg regions are inconsistent: " << AsText(expr);
if (region.defined() && region != arg_region) {
region_set_->MergeRegions(arg_region, region);
}
}
if (region.defined()) {
region_set_->AddToRegion(region, expr);
}
}

void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();

if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Propagate region to arguments
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (region.defined()) {
for (auto arg : call->args) {
region_set_->AddToRegion(region, arg);
}
}
AddToArgRegion(GetRef<Call>(call), call->args);
} else if (call->op == begin_op_) {
// The annotation node is inserted on edge so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;

// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (!region.defined()) {
throw Error(ErrorBuilder()
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
CHECK(!region.defined());

// Create a new region.
region = region_set_->MakeRegion(target);
region->nodes_.insert(GetRef<Call>(call));
region->ins_.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
Expand All @@ -122,9 +159,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
// Create a new region if the argument is not belonged to any regions yet.
region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n"
<< AsText(GetRef<Call>(call), false));
} else {
// If the argument is belonged to a region, it must have the same target.
// Otherwise we should see a region_begin op.
Expand All @@ -133,83 +169,44 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
for (auto field : op->fields) {
region_set_->AddToRegion(region, field);
}
}
ExprVisitor::VisitExpr_(op);
AddToArgRegion(GetRef<Tuple>(op), op->fields);
}

void VisitExpr_(const TupleGetItemNode* g) {
auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
if (region.defined()) {
region_set_->AddToRegion(region, g->tuple);
}
ExprVisitor::VisitExpr_(g);
}

void VisitExpr_(const FunctionNode* op) {
auto region = region_set_->GetRegion(GetRef<Function>(op));
if (region.defined()) {
for (auto param : op->params) {
region_set_->AddToRegion(region, param);
}
}
ExprVisitor::VisitExpr_(op);
Array<Expr> args = {g->tuple};
AddToArgRegion(GetRef<TupleGetItem>(g), args);
}

void VisitExpr_(const LetNode* op) {
auto region = region_set_->GetRegion(GetRef<Let>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->var);
region_set_->AddToRegion(region, op->value);
region_set_->AddToRegion(region, op->body);
}
Array<Expr> args = {op->var, op->value, op->body};
AddToArgRegion(GetRef<Let>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const IfNode* op) {
auto region = region_set_->GetRegion(GetRef<If>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->cond);
region_set_->AddToRegion(region, op->true_branch);
region_set_->AddToRegion(region, op->false_branch);
}
Array<Expr> args = {op->cond, op->true_branch, op->false_branch};
AddToArgRegion(GetRef<If>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefCreateNode* op) {
auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->value);
}
Array<Expr> args = {op->value};
AddToArgRegion(GetRef<RefCreate>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefReadNode* op) {
auto region = region_set_->GetRegion(GetRef<RefRead>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
Array<Expr> args = {op->ref};
AddToArgRegion(GetRef<RefRead>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefWriteNode* op) {
auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
Array<Expr> args = {op->ref};
AddToArgRegion(GetRef<RefWrite>(op), args);
ExprVisitor::VisitExpr_(op);
}

Expand Down
Loading

0 comments on commit 7cc370f

Please sign in to comment.