Skip to content

Commit

Permalink
[Heterogeneous][Bugfix] Fix bug of wrongly generated device_map (apac…
Browse files Browse the repository at this point in the history
…he#2990)

* fix bug of device_index

* cpplint

* nose

* Update test_pass_annotation.py

* fix name of testcase

* delete comment
  • Loading branch information
imorinaga authored and wweic committed May 13, 2019
1 parent 742fc2e commit 113a325
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 41 deletions.
71 changes: 38 additions & 33 deletions src/relay/pass/device_annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,9 +334,9 @@ class AnnotatationVisitor : private ExprVisitor {
* -Pass 1: Propagating the source device type to ops in a bottom-up way to the
* ancestors until encountering another copy op. For example, this way
* provides add, x, and y device types from the copy operator, `copy1`.
* -Pass 2: Propagating the destination device type of "the last" copy op in a
* top-down manner to the nodes on the output paths. For instance,
* this offers `subtract` and `exp` the same device type as `copy3`.
* -Pass 2: Propagating the destination device type of "the last" copy op to the
* remaining nodes. For instance, this offers `subtract` and `exp` the
* same device type as `copy3`.
*/

class DeviceInfo {
Expand Down Expand Up @@ -371,17 +371,22 @@ class DeviceInfo {
}

// Records a constant in post-DFS order, paired with the current value of
// has_copy_ (true while a device_copy call is being visited).
void VisitExpr_(const ConstantNode* cn) final {
  post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
}

// Records a call in post-DFS order after visiting its children.
//
// Annotation (on_device) calls are skipped entirely. For a device_copy call
// the counter is bumped and has_copy_ is forced to true for the duration of
// the visit, so the copy itself and everything visited beneath it are
// recorded with a true flag; the previous flag value is restored afterwards.
void VisitExpr_(const CallNode* call) final {
  // Skip annotation nodes.
  if (!IsOnDeviceNode(call)) {
    if (GetDeviceCopyNode(call)) {
      num_device_copy_ops_++;
      // Save the flag so nested/sibling traversal state is restored on exit.
      bool has_copy_prev = has_copy_;
      has_copy_ = true;
      ExprVisitor::VisitExpr_(call);
      post_dfs_order_.push_back(std::make_pair(call, has_copy_));
      has_copy_ = has_copy_prev;
    } else {
      ExprVisitor::VisitExpr_(call);
      post_dfs_order_.push_back(std::make_pair(call, has_copy_));
    }
  }
}
Expand All @@ -393,23 +398,27 @@ class DeviceInfo {

// Records a tuple projection in post-DFS order after visiting its operand,
// paired with the current value of has_copy_.
void VisitExpr_(const TupleGetItemNode* op) final {
  ExprVisitor::VisitExpr_(op);
  // The pair must actually be appended: a bare std::make_pair(...) statement
  // builds a temporary and discards it, leaving this node out of
  // post_dfs_order_ and therefore out of the device map.
  post_dfs_order_.push_back(std::make_pair(op, has_copy_));
}

// Records a variable in post-DFS order, paired with the current value of
// has_copy_.
void VisitExpr_(const VarNode* vn) final {
  post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}

// Records a let-binding in post-DFS order after visiting its children,
// paired with the current value of has_copy_.
void VisitExpr_(const LetNode* ln) final {
  ExprVisitor::VisitExpr_(ln);
  post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
}

// Records a conditional in post-DFS order after visiting its children,
// paired with the current value of has_copy_.
void VisitExpr_(const IfNode* in) final {
  ExprVisitor::VisitExpr_(in);
  post_dfs_order_.push_back(std::make_pair(in, has_copy_));
}


// Number of device_copy calls seen during the traversal.
int num_device_copy_ops_{0};
// True while a device_copy call (and everything beneath it) is being visited.
bool has_copy_ = false;
// Visited nodes in post-DFS order; the bool is the value of has_copy_ at the
// moment the node was recorded.
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
friend DeviceInfo;
};

Expand All @@ -435,46 +444,41 @@ class DeviceInfo {

// Builds device_map_ in two passes over the recorded post-DFS order:
// first the bottom-up pass, then a fill pass that gives every remaining node
// the device type returned by the bottom-up pass.
void PropagateDeviceId() {
  // Bottom-up propagation.
  int out_dev_type = BottomUpPropagation();
  // Propagation for the remaining nodes.
  FillPropagation(out_dev_type);
}

void BottomUpPropagation() {
int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(*it)) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
device_map_.Set(GetRef<Expr>(*it), attrs->dst_dev_type);
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first),
attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(*it);
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
device_map_.Set(expr, cur_dev_type);
if (it->second) device_map_.Set(expr, cur_dev_type);
}
}
return out_dev_type;
}

void TopDownPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
if (const auto* node = GetDeviceCopyNode(it)) {
last_copy_node = dynamic_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->dst_dev_type;
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it);
if (device_map_.count(expr) == 0) {
device_map_.Set(expr, cur_dev_type);
}
}
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}


// Visitor whose post_dfs_order_ is consumed by the propagation passes.
PostDfsOrderVisitor post_visitor_;
// Result of the propagation: expression -> device type.
Map<Expr, Integer> device_map_;
};
Expand Down Expand Up @@ -503,3 +507,4 @@ TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps")

} // namespace relay
} // namespace tvm

92 changes: 84 additions & 8 deletions tests/python/relay/test_pass_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def check_storage_and_device_types():
check_storage_and_device_types()


def test_fusible_network():
def run_fusible_network(dev, tgt):
R""" The network is as following:
x y
\ /
Expand Down Expand Up @@ -417,20 +417,96 @@ def test_fallback_all_operators(device, tgt):
check_annotated_graph(annotated_func, expected_func)
test_runtime(target, device, annotated_func)


test_fuse_log_add(dev, tgt)
test_fuse_all(dev, tgt)
test_fallback_exp(dev, tgt)
test_fallback_all_operators(dev, tgt)

def run_unpropagatable_graph(dev, tgt):
    R""" The network is as following:
            a     b  c     d
             \   /    \   /
              add      mul
                \      /
                subtract

    `add` and `subtract` are annotated for `dev`, `mul` for the CPU, so a
    single device_copy (cpu -> dev) is expected between `mul` and `subtract`.
    The function builds the annotated graph, checks it against the expected
    graph, verifies the per-node device indices emitted in the graph JSON, and
    compares the runtime result with a NumPy reference.
    """

    a = relay.var("a", shape=(10, 10))
    b = relay.var("b", shape=(10, 10))
    c = relay.var("c", shape=(10, 10))
    d = relay.var("d", shape=(10, 10))
    a_data = np.random.rand(10, 10).astype('float32')
    b_data = np.random.rand(10, 10).astype('float32')
    c_data = np.random.rand(10, 10).astype('float32')
    d_data = np.random.rand(10, 10).astype('float32')
    tmp_add = a_data + b_data
    tmp_mul = np.multiply(c_data, d_data)
    ref_res = np.subtract(tmp_add, tmp_mul)

    fallback_device = tvm.context("cpu")
    target = {"cpu": "llvm", dev: tgt}
    cpu_ctx = fallback_device
    dev_ctx = tvm.context(dev)

    def annotated():
        add = relay.add(a, b)
        _add = relay.annotation.on_device(add, dev_ctx)
        mul = relay.multiply(c, d)
        _mul = relay.annotation.on_device(mul, cpu_ctx)
        sub = relay.subtract(add, mul)
        _sub = relay.annotation.on_device(sub, dev_ctx)
        func = relay.Function([a, b, c, d],
                              relay.Tuple(tvm.convert([_add, _mul,
                                                       _sub, sub])))
        func = relay.ir_pass.infer_type(func)
        func = relay.ir_pass.rewrite_annotated_ops(func,
                                                   dev_ctx.device_type)
        func = relay.ir_pass.infer_type(func)
        # Field 3 of the tuple is the un-annotated `sub`; the annotation
        # entries only exist to drive the rewrite.
        return relay.Function(relay.ir_pass.free_vars(func.body[3]),
                              func.body[3])

    def expected():
        add = relay.add(a, b)
        mul = relay.multiply(c, d)
        copy_mul_sub = relay.device_copy(mul, cpu_ctx, dev_ctx)
        sub = relay.subtract(add, copy_mul_sub)
        func = relay.Function([a, b, c, d], sub)
        return func

    annotated_func = annotated()
    expected_func = expected()
    # Derive the expected indices from the contexts instead of hard-coding
    # them: the caller also runs this with OpenCL, whose device type is not
    # the same as CUDA's. For CUDA/CPU this yields [2, 2, 2, 1, 1, 1, 2, 2].
    dev_idx = dev_ctx.device_type
    cpu_idx = cpu_ctx.device_type
    expected_index = [dev_idx, dev_idx, dev_idx,
                      cpu_idx, cpu_idx, cpu_idx,
                      dev_idx, dev_idx]
    check_annotated_graph(annotated_func, expected_func)
    params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data}
    config = {"opt_level": 0}
    config["fallback_device"] = fallback_device
    with relay.build_config(**config):
        graph, lib, params = relay.build(annotated_func, target, params=params)
        contexts = [tvm.cpu(0), tvm.context(dev)]
        graph_json = json.loads(graph)
        if "device_index" in graph_json["attrs"]:
            device_index = graph_json["attrs"]["device_index"][1]
            assert device_index == expected_index
        mod = graph_runtime.create(graph, lib, contexts)
        mod.set_input(**params)
        mod.run()
        res = mod.get_output(0).asnumpy()
        tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5)

def test_check_run():
    """Run the device-dependent annotation tests on every enabled backend."""
    for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"),
                     ("opencl", str(tvm.target.intel_graphics()))]:
        if not tvm.module.enabled(dev):
            print("Skip test because %s is not enabled." % dev)
            continue

        run_fusible_network(dev, tgt)
        run_unpropagatable_graph(dev, tgt)


if __name__ == "__main__":
    test_redundant_annotation()
    test_annotate_all()
    test_annotate_none()
    test_conv_network()
    # The fusible-network test now takes (dev, tgt) and is driven by
    # test_check_run together with the unpropagatable-graph test.
    test_check_run()

0 comments on commit 113a325

Please sign in to comment.