Non-Recursive AnnotatedRegionSet and RegionMerger
comaniac committed Apr 22, 2020
1 parent af5ff80 commit a37b1dc
Showing 4 changed files with 102 additions and 105 deletions.
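The heart of this commit is moving AnnotatedRegionSet's Creator (and RegionMerger in merge_compiler_regions.cc) off the recursive ExprVisitor and onto MixedModeVisitor, which walks dataflow nodes iteratively so that very deep Relay graphs no longer overflow the C++ call stack. The sketch below contrasts the two traversal styles; the Node type and function names are hypothetical rather than TVM's actual classes, and the real MixedModeVisitor additionally memoizes visited nodes.

// recursive_vs_iterative.cc (illustrative only, C++17)
#include <iostream>
#include <stack>
#include <utility>
#include <vector>

struct Node {
  int id;
  std::vector<Node*> args;
};

// Recursive post-order visit: every nested node adds a C++ stack
// frame, so a graph thousands of nodes deep can crash the process.
void VisitRecursive(Node* n) {
  for (Node* arg : n->args) VisitRecursive(arg);
  std::cout << "visit " << n->id << "\n";
}

// Iterative post-order visit: an explicit worklist replaces the call
// stack, so depth is bounded by heap memory instead. This is the idea
// behind MixedModeVisitor's non-recursive handling of dataflow nodes.
void VisitIterative(Node* root) {
  std::stack<std::pair<Node*, bool>> todo;
  todo.push({root, false});
  while (!todo.empty()) {
    auto [n, expanded] = todo.top();
    todo.pop();
    if (expanded) {
      std::cout << "visit " << n->id << "\n";  // args already visited
    } else {
      todo.push({n, true});
      // Push args in reverse so they pop in source order.
      for (auto it = n->args.rbegin(); it != n->args.rend(); ++it)
        todo.push({*it, false});
    }
  }
}

int main() {
  Node a{0, {}}, b{1, {&a}}, c{2, {&b}};
  VisitRecursive(&c);  // prints 0 1 2
  VisitIterative(&c);  // prints 0 1 2 as well
  return 0;
}

One consequence worth noting before the test diff: regions are created in visitation order, so changing the traversal also renumbers the partitioned functions (test_compiler_0 becomes test_compiler_2, test_target_1 becomes test_target_0, and so on) and reorders tuples at region boundaries. Most of the updates to test_pass_partition_graph.py below are this renumbering rather than behavioral changes.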
133 changes: 65 additions & 68 deletions src/relay/analysis/annotated_region_set.cc
@@ -86,32 +86,69 @@ AnnotatedRegion AnnotatedRegionSetNode::MakeRegion(const std::string& target) {
return *ret.first;
}

class AnnotatedRegionSet::Creator : public ExprVisitor {
class AnnotatedRegionSet::Creator : protected MixedModeVisitor {
public:
Creator(const Op& region_begin_op, const Op& region_end_op)
: begin_op_(region_begin_op), end_op_(region_end_op) {}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void AddToArgRegion(Expr expr, Array<Expr> args) {
// Merge argument regions and add itself to the region.

// Find the first open region.
AnnotatedRegion region;
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
continue;
}

region = region_set_->GetRegion(arg);
if (region.defined()) {
break;
}
}

// Try to merge open regions.
for (auto arg : args) {
const CallNode* end = arg.as<CallNode>();
if (end && end->op == end_op_) { // Ignore closed regions.
continue;
}

auto arg_region = region_set_->GetRegion(arg);
CHECK_EQ(region.defined(), arg_region.defined())
<< "Arg regions are inconsistent: " << AsText(expr);
if (region.defined() && region != arg_region) {
region_set_->MergeRegions(arg_region, region);
}
}
if (region.defined()) {
region_set_->AddToRegion(region, expr);
}
}

void VisitExpr_(const CallNode* call) {
auto op_node = call->op.as<OpNode>();

if (op_node == nullptr || call->attrs.as<CompilerAttrs>() == nullptr) {
// Propagate region to arguments
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (region.defined()) {
for (auto arg : call->args) {
region_set_->AddToRegion(region, arg);
}
}
AddToArgRegion(GetRef<Call>(call), call->args);
} else if (call->op == begin_op_) {
// The annotation node is inserted on an edge, so it must have only one argument.
CHECK_EQ(call->args.size(), 1U);
std::string target = call->attrs.as<CompilerAttrs>()->compiler;

// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(GetRef<Call>(call));
if (!region.defined()) {
throw Error(ErrorBuilder()
<< "Cannot find the corresponding region for start annotation:\n"
<< AsText(GetRef<Call>(call), false));
}
CHECK(!region.defined());

// Create a new region.
region = region_set_->MakeRegion(target);
region->nodes_.insert(GetRef<Call>(call));
region->ins_.push_back(GetRef<Call>(call));
} else {
CHECK_EQ(call->op, end_op_);
@@ -122,9 +159,8 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
// Check if the argument already belongs to a region
auto region = region_set_->GetRegion(call->args[0]);
if (!region.defined()) {
// Create a new region if the argument does not belong to any region yet.
region = region_set_->MakeRegion(target);
region->nodes_.insert(call->args[0]);
throw Error(ErrorBuilder() << "Cannot find the corresponding region for end annotation:\n"
<< AsText(GetRef<Call>(call), false));
} else {
// If the argument belongs to a region, it must have the same target.
// Otherwise we should see a region_begin op.
@@ -133,83 +169,44 @@ class AnnotatedRegionSet::Creator : public ExprVisitor {
region->nodes_.insert(GetRef<Call>(call));
region->outs_.push_back(GetRef<Call>(call));
}
ExprVisitor::VisitExpr_(call);
}

AnnotatedRegionSet Create(const Expr& expr) {
VisitExpr(expr);
return std::move(region_set_);
}

void VisitExpr_(const TupleNode* op) {
auto region = region_set_->GetRegion(GetRef<Tuple>(op));
if (region.defined()) {
for (auto field : op->fields) {
region_set_->AddToRegion(region, field);
}
}
ExprVisitor::VisitExpr_(op);
AddToArgRegion(GetRef<Tuple>(op), op->fields);
}

void VisitExpr_(const TupleGetItemNode* g) {
auto region = region_set_->GetRegion(GetRef<TupleGetItem>(g));
if (region.defined()) {
region_set_->AddToRegion(region, g->tuple);
}
ExprVisitor::VisitExpr_(g);
}

void VisitExpr_(const FunctionNode* op) {
auto region = region_set_->GetRegion(GetRef<Function>(op));
if (region.defined()) {
for (auto param : op->params) {
region_set_->AddToRegion(region, param);
}
}
ExprVisitor::VisitExpr_(op);
Array<Expr> args = {g->tuple};
AddToArgRegion(GetRef<TupleGetItem>(g), args);
}

void VisitExpr_(const LetNode* op) {
auto region = region_set_->GetRegion(GetRef<Let>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->var);
region_set_->AddToRegion(region, op->value);
region_set_->AddToRegion(region, op->body);
}
Array<Expr> args = {op->var, op->value, op->body};
AddToArgRegion(GetRef<Let>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const IfNode* op) {
auto region = region_set_->GetRegion(GetRef<If>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->cond);
region_set_->AddToRegion(region, op->true_branch);
region_set_->AddToRegion(region, op->false_branch);
}
Array<Expr> args = {op->cond, op->true_branch, op->false_branch};
AddToArgRegion(GetRef<If>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefCreateNode* op) {
auto region = region_set_->GetRegion(GetRef<RefCreate>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->value);
}
Array<Expr> args = {op->value};
AddToArgRegion(GetRef<RefCreate>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefReadNode* op) {
auto region = region_set_->GetRegion(GetRef<RefRead>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
Array<Expr> args = {op->ref};
AddToArgRegion(GetRef<RefRead>(op), args);
ExprVisitor::VisitExpr_(op);
}

void VisitExpr_(const RefWriteNode* op) {
auto region = region_set_->GetRegion(GetRef<RefWrite>(op));
if (region.defined()) {
region_set_->AddToRegion(region, op->ref);
}
Array<Expr> args = {op->ref};
AddToArgRegion(GetRef<RefWrite>(op), args);
ExprVisitor::VisitExpr_(op);
}

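The new AddToArgRegion helper above replaces near-identical propagation code in each VisitExpr_ overload with one rule: ignore arguments whose region was already closed by a compiler_end, adopt the first open argument region found, merge every other open argument region into it, and finally add the node itself. Below is a standalone sketch of that merge rule under simplifying assumptions: plain int region ids and a map stand in for TVM's AnnotatedRegionSet, and the closed-region check is omitted.

// add_to_arg_region_sketch.cc (illustrative only)
#include <cassert>
#include <map>
#include <vector>

using NodeId = int;
using RegionId = int;
constexpr RegionId kNoRegion = -1;

std::map<NodeId, RegionId> node_region;  // node -> region id

RegionId GetRegion(NodeId n) {
  auto it = node_region.find(n);
  return it == node_region.end() ? kNoRegion : it->second;
}

// Mirror of the AddToArgRegion logic: adopt the first argument's
// region, merge the other argument regions into it, then join it.
void AddToArgRegion(NodeId node, const std::vector<NodeId>& args) {
  RegionId region = kNoRegion;
  for (NodeId arg : args) {
    region = GetRegion(arg);
    if (region != kNoRegion) break;  // first open region wins
  }
  if (region == kNoRegion) return;   // no argument has a region yet

  for (NodeId arg : args) {
    RegionId r = GetRegion(arg);
    if (r != kNoRegion && r != region) {
      // MergeRegions: relabel every node of r into the chosen region.
      for (auto& kv : node_region)
        if (kv.second == r) kv.second = region;
    }
  }
  node_region[node] = region;  // the node joins the merged region
}

int main() {
  node_region = {{1, 10}, {2, 20}};  // two separate open regions
  AddToArgRegion(3, {1, 2});         // node 3 consumes both
  assert(GetRegion(1) == 10 && GetRegion(2) == 10 && GetRegion(3) == 10);
  return 0;
}

The CHECK_EQ in the real helper adds a consistency guarantee the sketch skips: once one open argument is found to have a region, every other open argument must have one too, otherwise the annotations are malformed.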
2 changes: 1 addition & 1 deletion src/relay/transforms/annotate_target.cc
@@ -192,7 +192,7 @@ class AnnotateTargetRewriter : public ExprRewriter {
// Update the target map.
op_expr_to_target_[new_call] = target;

return new_call;
return std::move(new_call);
}

Expr Rewrite_(const TupleNode* op, const Expr& post) final {
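The single-line annotate_target.cc change above is a small copy-avoidance fix. Rewrite_ builds a local Call but returns it through an Expr return type; because the two types differ, the implicit move-on-return rule is not guaranteed to apply before C++20, so without std::move the reference-counted handle may be copy-constructed (an extra refcount bump) rather than moved. A minimal sketch of the pattern, with hypothetical handle types standing in for TVM's Expr and Call:

// return_move_sketch.cc (illustrative only)
#include <utility>

struct Expr { int* state = nullptr; /* shared handle state */ };
struct Call : Expr {};

Expr Rewrite() {
  Call new_call;
  // Return type (Expr) differs from operand type (Call); pre-C++20
  // the operand may be treated as an lvalue and copied, so std::move
  // makes the cheap move explicit.
  return std::move(new_call);
}

int main() {
  Expr e = Rewrite();
  (void)e;
  return 0;
}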
3 changes: 1 addition & 2 deletions src/relay/transforms/merge_compiler_regions.cc
@@ -53,7 +53,7 @@ namespace merge_compiler_region {
static const Op& compiler_begin_op = Op::Get("annotation.compiler_begin");
static const Op& compiler_end_op = Op::Get("annotation.compiler_end");

class RegionMerger : public ExprVisitor {
class RegionMerger : public MixedModeVisitor {
public:
explicit RegionMerger(AnnotatedRegionSet regions) : regions_(regions) {}

@@ -131,7 +131,6 @@ class RegionMerger : public ExprVisitor {
}
merged_regions_.insert(region->GetID());
}
ExprVisitor::VisitExpr_(call);
}

private:
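RegionMerger gets the same treatment: it becomes a MixedModeVisitor and the trailing ExprVisitor::VisitExpr_(call) self-dispatch is deleted. Once the base class drives the traversal and reaches each node exactly once, a handler that recursed into its arguments would walk them a second time. A toy sketch of that contract, with a hypothetical Driver class rather than TVM's API:

// driver_contract_sketch.cc (illustrative only)
#include <iostream>
#include <vector>

struct Node {
  int id;
  std::vector<Node*> args;
};

// Toy stand-in for MixedModeVisitor: the driver enumerates the nodes
// itself (here the caller supplies them in post-order), so handlers
// perform per-node work only and must not recurse.
class Driver {
 public:
  virtual ~Driver() = default;
  void Run(const std::vector<Node*>& post_order) {
    for (Node* n : post_order) Handle(n);
  }

 protected:
  virtual void Handle(Node* n) = 0;
};

class Merger : public Driver {
 protected:
  void Handle(Node* n) override {
    // Per-node merge step. Recursing into n->args here would visit
    // each argument twice (once via Run's loop, once via the
    // recursion); keeping the old ExprVisitor::VisitExpr_(call) call
    // under MixedModeVisitor would have exactly that effect.
    std::cout << "merge step for node " << n->id << "\n";
  }
};

int main() {
  Node a{0, {}}, b{1, {&a}};
  Merger m;
  m.Run({&a, &b});  // prints a merge step for node 0, then node 1
  return 0;
}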
69 changes: 35 additions & 34 deletions tests/python/relay/test_pass_partition_graph.py
@@ -522,8 +522,8 @@ def expected():
bn = relay.nn.batch_norm(data0, bn_gamma, bn_beta, bn_mmean, bn_mvar)
func0 = relay.Function([data0, bn_gamma, bn_beta, bn_mmean, bn_mvar],
bn.astuple())
func0 = set_func_attr(func0, "test_compiler", "test_compiler_0")
gv0 = relay.GlobalVar("test_compiler_0")
func0 = set_func_attr(func0, "test_compiler", "test_compiler_2")
gv0 = relay.GlobalVar("test_compiler_2")
mod[gv0] = func0

# function for conv2d
@@ -536,8 +536,8 @@ def expected():
channels=16,
padding=(1, 1))
func1 = relay.Function([data1, weight1], conv)
func1 = set_func_attr(func1, "test_compiler", "test_compiler_1")
gv1 = relay.GlobalVar("test_compiler_1")
func1 = set_func_attr(func1, "test_compiler", "test_compiler_0")
gv1 = relay.GlobalVar("test_compiler_0")
mod[gv1] = func1

# main function
@@ -630,7 +630,6 @@ def test_constant_propagation():

def expected():
mod = tvm.IRModule()
x = relay.const(ones)
y = relay.var("y", shape=(8, 8))
x0 = relay.const(ones)
y0 = relay.var("y0", shape=(8, 8))
@@ -712,12 +711,12 @@ def expected():
mod = tvm.IRModule()

# function 0
data = relay.var("test_target_2_i0", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("test_target_2_i1", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("test_target_2_i2", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("test_target_2_i3", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("test_target_2_i4", relay.TensorType((16, ), "float32"))
bn_var = relay.var("test_target_2_i5", relay.TensorType((16, ), "float32"))
data = relay.var("test_target_0_i0", relay.TensorType((1, 3, 224, 224), "float32"))
weight = relay.var("test_target_0_i1", relay.TensorType((16, 3, 3, 3), "float32"))
bn_gamma = relay.var("test_target_0_i2", relay.TensorType((16, ), "float32"))
bn_beta = relay.var("test_target_0_i3", relay.TensorType((16, ), "float32"))
bn_mean = relay.var("test_target_0_i4", relay.TensorType((16, ), "float32"))
bn_var = relay.var("test_target_0_i5", relay.TensorType((16, ), "float32"))

conv_o = relay.nn.conv2d(
data=data,
@@ -730,12 +729,12 @@ def expected():
bn_var)

relu_o = relay.nn.relu(bn_o[0])
tuple_o = relay.Tuple((bn_o[2], bn_o[1], relu_o))
tuple_o = relay.Tuple((relu_o, bn_o[1], bn_o[2]))

func0 = relay.Function([data, weight, bn_gamma, bn_beta,
bn_mean, bn_var], tuple_o)
func0 = set_func_attr(func0, "test_target", "test_target_2")
gv0 = relay.GlobalVar("test_target_2")
func0 = set_func_attr(func0, "test_target", "test_target_0")
gv0 = relay.GlobalVar("test_target_0")
mod[gv0] = func0

# body
@@ -747,9 +746,9 @@ def expected():
bn_var = relay.var("bn_var", relay.TensorType((16, ), "float32"))

f0_o = gv0(data, weight, bn_gamma, bn_beta, bn_mean, bn_var)
f0_relu_o = relay.TupleGetItem(f0_o, 2)
f0_relu_o = relay.TupleGetItem(f0_o, 0)
f0_mean_o = relay.TupleGetItem(f0_o, 1)
f0_var_o = relay.TupleGetItem(f0_o, 0)
f0_var_o = relay.TupleGetItem(f0_o, 2)

f0_mean_abs = relay.abs(f0_mean_o)
f0_var_abs = relay.abs(f0_var_o)
@@ -763,7 +762,9 @@ def expected():
mod = tvm.IRModule()
mod["main"] = create_graph()
ref_mod = expected()
print(ref_mod)
partitioned = transform.PartitionGraph()(mod)
print(partitioned)
assert tvm.ir.structural_equal(partitioned, ref_mod, map_free_vars=True)


@@ -791,22 +792,22 @@ def expected():
mod = tvm.IRModule()

# function 1
f1_cb1 = relay.var('test_target_1_i0', shape=(10, 10))
f1_cb1 = relay.var('test_target_0_i0', shape=(10, 10))
f1_O_1 = relay.abs(f1_cb1)
f1_O_2 = relay.nn.relu(f1_O_1)
f1_out = relay.Tuple((f1_O_2, f1_O_1))
func1 = relay.Function([f1_cb1], f1_out)
func1 = set_func_attr(func1, "test_target", "test_target_1")
gv1 = relay.GlobalVar("test_target_1")
func1 = set_func_attr(func1, "test_target", "test_target_0")
gv1 = relay.GlobalVar("test_target_0")
mod[gv1] = func1

# function 0
f2_cb3 = relay.var('test_target_0_i0', shape=(10, 10))
f2_cb4 = relay.var('test_target_0_i1', shape=(10, 10))
f2_cb3 = relay.var('test_target_1_i0', shape=(10, 10))
f2_cb4 = relay.var('test_target_1_i1', shape=(10, 10))
f2_O_3 = relay.add(f2_cb3, f2_cb4)
func0 = relay.Function([f2_cb3, f2_cb4], f2_O_3)
func0 = set_func_attr(func0, "test_target", "test_target_0")
gv0 = relay.GlobalVar("test_target_0")
func0 = set_func_attr(func0, "test_target", "test_target_1")
gv0 = relay.GlobalVar("test_target_1")
mod[gv0] = func0

# body
@@ -1109,22 +1110,22 @@ def expected():
mod = tvm.IRModule()

# function 0
f0_i0 = relay.var(target+"_1_i0", shape=(10, 10))
f0_i1 = relay.var(target+"_1_i1")
f0_i2 = relay.var(target+"_1_i2")
f0_i3 = relay.var(target+"_1_i3")
f0_i4 = relay.var(target+"_1_i4")
f0_i0 = relay.var(target + "_0_i0", shape=(10, 10))
f0_i1 = relay.var(target + "_0_i1")
f0_i2 = relay.var(target + "_0_i2")
f0_i3 = relay.var(target + "_0_i3")
f0_i4 = relay.var(target + "_0_i4")
f0_n0 = relay.nn.batch_norm(f0_i0, f0_i1, f0_i2, f0_i3, f0_i4)
f0_n1 = f0_n0[1]
f0_n2 = relay.nn.relu(f0_n0[0])
f0_o0 = relay.Tuple([f0_n1, f0_n2])
f0_o0 = relay.Tuple([f0_n2, f0_n1])
func0 = relay.Function([f0_i0, f0_i1, f0_i2, f0_i3, f0_i4], f0_o0)

func0 = func0.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Inline", tvm.tir.IntImm("int32", 1))
func0 = func0.with_attr("Compiler", target)
func0 = func0.with_attr("global_symbol", target+"_1")
gv0 = relay.GlobalVar(target+"_1")
func0 = func0.with_attr("global_symbol", target + "_0")
gv0 = relay.GlobalVar(target + "_0")
mod[gv0] = func0

# body
@@ -1136,9 +1137,9 @@ def expected():
function_out = gv0(data, bn_gamma, bn_beta, bn_mmean, bn_mvar)
get_out0 = relay.TupleGetItem(function_out, 0)
get_out1 = relay.TupleGetItem(function_out, 1)
out_2 = relay.tanh(get_out0)
out_3 = relay.log(get_out0)
out = relay.Tuple([get_out1, out_2, out_3])
out_2 = relay.tanh(get_out1)
out_3 = relay.log(get_out1)
out = relay.Tuple([get_out0, out_2, out_3])
func = relay.Function([data, bn_gamma, bn_beta, bn_mmean, bn_mvar], out)
mod["main"] = func
return mod
