diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 46f4268cc970..0139cc912849 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -334,9 +334,9 @@ class AnnotatationVisitor : private ExprVisitor { * -Pass 1: Propagating the source device type to ops in a bottom-up way to the * ancestors until encountering another copy op. For example, this way * provides add, x, and y device types from the copy operator, `copy1`. - * -Pass 2: Propagating the destination device type of "the last" copy op in a - * top-down manner to the nodes on the output paths. For instance, - * this offers `subtract` and `exp` the same device type as `copy3`. + * -Pass 2: Propagating the destination device type of "the last" copy op to the + * remaining nodes. For instance, this offers `subtract` and `exp` the + * same device type as `copy3`. */ class DeviceInfo { @@ -371,17 +371,22 @@ class DeviceInfo { } void VisitExpr_(const ConstantNode* cn) final { - post_dfs_order_.push_back(cn); + post_dfs_order_.push_back(std::make_pair(cn, has_copy_)); } void VisitExpr_(const CallNode* call) final { // Skip annotation nodes. 
if (!IsOnDeviceNode(call)) { - ExprVisitor::VisitExpr_(call); - post_dfs_order_.push_back(call); - if (GetDeviceCopyNode(call)) { num_device_copy_ops_++; + bool has_copy_prev = has_copy_; + has_copy_ = true; + ExprVisitor::VisitExpr_(call); + post_dfs_order_.push_back(std::make_pair(call, has_copy_)); + has_copy_ = has_copy_prev; + } else { + ExprVisitor::VisitExpr_(call); + post_dfs_order_.push_back(std::make_pair(call, has_copy_)); } } } @@ -393,23 +398,27 @@ class DeviceInfo { void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); - post_dfs_order_.push_back(op); + post_dfs_order_.push_back(std::make_pair(op, has_copy_)); } - void VisitExpr_(const VarNode* vn) final { post_dfs_order_.push_back(vn); } + void VisitExpr_(const VarNode* vn) final { + post_dfs_order_.push_back(std::make_pair(vn, has_copy_)); + } void VisitExpr_(const LetNode* ln) final { ExprVisitor::VisitExpr_(ln); - post_dfs_order_.push_back(ln); + post_dfs_order_.push_back(std::make_pair(ln, has_copy_)); } void VisitExpr_(const IfNode* in) final { ExprVisitor::VisitExpr_(in); - post_dfs_order_.push_back(in); + post_dfs_order_.push_back(std::make_pair(in, has_copy_)); } + int num_device_copy_ops_{0}; - std::vector<const ExprNode*> post_dfs_order_; + bool has_copy_ = false; + std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_; friend DeviceInfo; }; @@ -435,46 +444,41 @@ class DeviceInfo { void PropagateDeviceId() { // Bottom-up propagation. - BottomUpPropagation(); - // Top-down propagation. - TopDownPropagation(); + int out_dev_type = BottomUpPropagation(); + // Propagation for the remaining nodes. 
+ FillPropagation(out_dev_type); } - void BottomUpPropagation() { + int BottomUpPropagation() { const CallNode* last_copy_node = nullptr; int cur_dev_type = -1; + int out_dev_type = -1; for (auto it = post_visitor_.post_dfs_order_.crbegin(); it != post_visitor_.post_dfs_order_.crend(); ++it) { - if (const auto* node = GetDeviceCopyNode(*it)) { + if (const auto* node = GetDeviceCopyNode(it->first)) { last_copy_node = dynamic_cast<const CallNode*>(node); const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>(); cur_dev_type = attrs->src_dev_type; - device_map_.Set(GetRef<Expr>(*it), attrs->dst_dev_type); + if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type; + if (it->second) device_map_.Set(GetRef<Expr>(it->first), + attrs->dst_dev_type); } else if (last_copy_node) { - Expr expr = GetRef<Expr>(*it); + Expr expr = GetRef<Expr>(it->first); CHECK_EQ(device_map_.count(expr), 0U); - device_map_.Set(expr, cur_dev_type); + if (it->second) device_map_.Set(expr, cur_dev_type); } } + return out_dev_type; } - void TopDownPropagation() { - const CallNode* last_copy_node = nullptr; - int cur_dev_type = -1; + void FillPropagation(int out_dev_type) { for (const auto& it : post_visitor_.post_dfs_order_) { - if (const auto* node = GetDeviceCopyNode(it)) { - last_copy_node = dynamic_cast<const CallNode*>(node); - const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>(); - cur_dev_type = attrs->dst_dev_type; - } else if (last_copy_node) { - Expr expr = GetRef<Expr>(it); - if (device_map_.count(expr) == 0) { - device_map_.Set(expr, cur_dev_type); - } - } + Expr expr = GetRef<Expr>(it.first); + if (!it.second) device_map_.Set(expr, out_dev_type); } } + PostDfsOrderVisitor post_visitor_; Map<Expr, Integer> device_map_; }; @@ -503,3 +507,4 @@ TVM_REGISTER_API("relay._ir_pass.CollectDeviceAnnotationOps") } // namespace relay } // namespace tvm + diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index c55a9fb2dd85..04081e06735b 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ 
-231,7 +231,7 @@ def check_storage_and_device_types(): check_storage_and_device_types() -def test_fusible_network(): +def run_fusible_network(dev, tgt): R""" The network is as following: x y \ / @@ -417,20 +417,96 @@ def test_fallback_all_operators(device, tgt): check_annotated_graph(annotated_func, expected_func) test_runtime(target, device, annotated_func) + + test_fuse_log_add(dev, tgt) + test_fuse_all(dev, tgt) + test_fallback_exp(dev, tgt) + test_fallback_all_operators(dev, tgt) + +def run_unpropagatable_graph(dev, tgt): + R""" The network is as follows: + a b c d + \ / \ / + add mul + \ / + subtract + """ + + a = relay.var("a", shape=(10, 10)) + b = relay.var("b", shape=(10, 10)) + c = relay.var("c", shape=(10, 10)) + d = relay.var("d", shape=(10, 10)) + a_data = np.random.rand(10, 10).astype('float32') + b_data = np.random.rand(10, 10).astype('float32') + c_data = np.random.rand(10, 10).astype('float32') + d_data = np.random.rand(10, 10).astype('float32') + tmp_add = a_data + b_data + tmp_mul = np.multiply(c_data, d_data) + ref_res = np.subtract(tmp_add, tmp_mul) + + fallback_device = tvm.context("cpu") + target = {"cpu": "llvm", dev: tgt} + cpu_ctx = fallback_device + dev_ctx = tvm.context(dev) + + def annotated(): + add = relay.add(a, b) + _add = relay.annotation.on_device(add, dev_ctx) + mul = relay.multiply(c, d) + _mul = relay.annotation.on_device(mul, cpu_ctx) + sub = relay.subtract(add, mul) + _sub = relay.annotation.on_device(sub, dev_ctx) + func = relay.Function([a, b, c, d], + relay.Tuple(tvm.convert([_add, _mul, + _sub, sub]))) + func = relay.ir_pass.infer_type(func) + func = relay.ir_pass.rewrite_annotated_ops(func, + dev_ctx.device_type) + func = relay.ir_pass.infer_type(func) + return relay.Function(relay.ir_pass.free_vars(func.body[3]), + func.body[3]) + + def expected(): + add = relay.add(a, b) + mul = relay.multiply(c, d) + copy_mul_sub = relay.device_copy(mul, cpu_ctx, dev_ctx) + sub = relay.subtract(add, copy_mul_sub) + func = 
relay.Function([a, b, c, d], sub) + return func + + annotated_func = annotated() + expected_func = expected() + expected_index = [2, 2, 2, 1, 1, 1, 2, 2] + check_annotated_graph(annotated_func, expected_func) + params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} + config = {"opt_level": 0} + config["fallback_device"] = fallback_device + with relay.build_config(**config): + graph, lib, params = relay.build(annotated_func, target, params=params) + contexts = [tvm.cpu(0), tvm.context(dev)] + graph_json = json.loads(graph) + if "device_index" in graph_json["attrs"]: + device_index = graph_json["attrs"]["device_index"][1] + assert device_index == expected_index + mod = graph_runtime.create(graph, lib, contexts) + mod.set_input(**params) + mod.run() + res = mod.get_output(0).asnumpy() + tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) + +def test_check_run(): for dev, tgt in [("opencl", "opencl"), ("cuda", "cuda"), - ("opencl", str(tvm.target.intel_graphics()))]: + ("opencl", str(tvm.target.intel_graphics()))]: if not tvm.module.enabled(dev): print("Skip test because %s is not enabled." % dev) continue - test_fuse_log_add(dev, tgt) - test_fuse_all(dev, tgt) - test_fallback_exp(dev, tgt) - test_fallback_all_operators(dev, tgt) - + run_fusible_network(dev, tgt) + run_unpropagatable_graph(dev, tgt) + if __name__ == "__main__": test_redundant_annotation() test_annotate_all() test_annotate_none() test_conv_network() - test_fusible_network() + test_check_run()