Skip to content

Commit

Permalink
[Relay] change device annotation from post DFS to recursive (apache#6124
Browse files Browse the repository at this point in the history
)

* change device annotation from post DFS to recursive

* add testcast for recursive device propogation
  • Loading branch information
zhanghaohit authored and trevor-m committed Sep 3, 2020
1 parent 6957795 commit 93f8384
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 49 deletions.
83 changes: 34 additions & 49 deletions src/relay/transforms/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -385,23 +385,34 @@ class DeviceInfo {
// TODO(zhiics) Skip annotation of function node for now.
}

void VisitExpr_(const ConstantNode* cn) final {
post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
}
void VisitExpr_(const ConstantNode* cn) final { device_tag_[cn] = dev_type_; }

void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
if (GetDeviceCopyNode(call)) {
if (const auto* node = GetDeviceCopyNode(call)) {
CHECK(node->IsInstance<CallNode>());
const auto* call_node = static_cast<const CallNode*>(node);
auto attrs = call_node->attrs.as<DeviceCopyAttrs>();

num_device_copy_ops_++;
bool has_copy_prev = has_copy_;
has_copy_ = true;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
has_copy_ = has_copy_prev;
dev_type_ = attrs->src_dev_type;
for (auto& arg : call->args) {
Visit(arg);
// restore the type for remaining arguments
dev_type_ = attrs->src_dev_type;
}
device_tag_[call] = attrs->dst_dev_type;
// update the out_dev_type_, which should be the dst_dev_type of last copy
out_dev_type_ = attrs->dst_dev_type;
} else {
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
for (auto& arg : call->args) {
int cur_dev_type = dev_type_;
Visit(arg);
// restore the type for remaining arguments
dev_type_ = cur_dev_type;
}
device_tag_[call] = dev_type_;
}
}
}
Expand All @@ -413,23 +424,22 @@ class DeviceInfo {

void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }

void VisitExpr_(const VarNode* vn) final {
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}
void VisitExpr_(const VarNode* vn) final { device_tag_[vn] = dev_type_; }

void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
device_tag_[ln] = dev_type_;
}

void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
device_tag_[in] = dev_type_;
}

int num_device_copy_ops_{0};
bool has_copy_ = false;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
int dev_type_ = -1;
int out_dev_type_ = -1;
std::unordered_map<const ExprNode*, int> device_tag_;
friend DeviceInfo;
};

Expand All @@ -455,39 +465,14 @@ class DeviceInfo {
}

void PropagateDeviceId() {
// Bottom-up propagation.
int out_dev_type = BottomUpPropagation();
// propagation for remained nodes.
FillPropagation(out_dev_type);
}

int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
CHECK(node->IsInstance<CallNode>());
last_copy_node = static_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
if (it->second) device_map_.Set(expr, cur_dev_type);
int out_dev_type = post_visitor_.out_dev_type_;
for (auto& it : post_visitor_.device_tag_) {
if (it.second != -1) {
device_map_.Set(GetRef<Expr>(it.first), it.second);
} else {
device_map_.Set(GetRef<Expr>(it.first), out_dev_type);
}
}
return out_dev_type;
}

void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}

PostDfsOrderVisitor post_visitor_;
Expand Down
70 changes: 70 additions & 0 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,76 @@ def test_visitor_annotation():
test_visitor_annotation()


def test_propogation():
R""" The network and device type is as following:
x 1
|
log 1
/ \
log2 log10 2
\ /
add 2
|
tan 1
"""
ctx1 = tvm.context(1)
ctx2 = tvm.context(2)

expected_dev_type = {
'log': ctx1,
'log2': ctx2,
'log10': ctx2,
'add': ctx2,
'tan': ctx1
}

x = relay.var("x", shape=(3,))

def annotated():
log = relay.log(x)
_log = relay.annotation.on_device(log, expected_dev_type['log'])
log2 = relay.log2(_log)
_log2 = relay.annotation.on_device(log2, expected_dev_type['log2'])
log10 = relay.log10(_log)
_log10 = relay.annotation.on_device(log10, expected_dev_type['log10'])
add = relay.add(_log2, _log10)
_add = relay.annotation.on_device(add, expected_dev_type['add'])
tan = relay.tan(_add)
_tan = relay.annotation.on_device(tan, expected_dev_type['tan'])

func = run_opt_pass(_tan, transform.RewriteAnnotatedOps(ctx1.device_type))
return func

def expected():
log = relay.log(x)
_log_left = relay.device_copy(log, ctx1, ctx2)
_log_right = relay.device_copy(log, ctx1, ctx2)
log2 = relay.log2(_log_left)
log10 = relay.log10(_log_right)
add = relay.add(log2, log10)
_add = relay.device_copy(add, ctx2, ctx1)
tan = relay.tan(_add)

func = run_opt_pass(tan, transform.InferType())
return func

annotated_expr = annotated()
expected_expr = expected()
assert tvm.ir.structural_equal(annotated_expr, expected_expr)

smap = relay.backend._backend.GraphPlanMemory(annotated_expr)
for expr, storage_dev_type in smap.items():
# x is ctx1 as output is ctx1
if isinstance(expr, tvm.relay.expr.Var):
assert storage_dev_type[1][0] == ctx1.device_type
else:
# device_copy op should be its dst_dev_type
if isinstance(expr.attrs, tvm.relay.op.op_attrs.DeviceCopyAttrs):
assert storage_dev_type[1][0] == expr.attrs.dst_dev_type
else:
assert storage_dev_type[1][0] == expected_dev_type[expr.op.name].device_type


def run_fusible_network(dev, tgt):
R""" The network is as following:
x y
Expand Down

0 comments on commit 93f8384

Please sign in to comment.