diff --git a/src/relay/pass/device_annotation.cc b/src/relay/pass/device_annotation.cc index 0139cc912849..8807f6dd4cf4 100644 --- a/src/relay/pass/device_annotation.cc +++ b/src/relay/pass/device_annotation.cc @@ -485,7 +485,52 @@ class DeviceInfo { Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) { RewriteAnnotation rewrote = RewriteAnnotation(); - return rewrote.Rewrite(expr, fallback_device); + Expr new_expr = rewrote.Rewrite(expr, fallback_device); + + // Remove OnDevice operators. Note that these operators are only present at the + // leaves after annotation. Therefore, we can simply reconstruct the + // Function/Expr by removing them directly. + if (const FunctionNode* fn = new_expr.as()) { + auto params = fn->params; + auto body = fn->body; + std::vector new_body; + if (const TupleNode* tuple = body.as()) { + for (const auto& field : tuple->fields) { + if (!IsOnDeviceNode(field.operator->())) { + new_body.push_back(field); + } + } + CHECK_GT(new_body.size(), 0U); + if (new_body.size() == 1) { + return FunctionNode::make(params, new_body[0], Type(nullptr), + fn->type_params, fn->attrs); + } else if (tuple->fields.size() == new_body.size()) { + return new_expr; + } else { + Tuple tuple_body = TupleNode::make(new_body); + return FunctionNode::make(params, tuple_body, Type(nullptr), + fn->type_params, fn->attrs); + } + } else { + return new_expr; + } + } else if (const TupleNode* tuple = new_expr.as()) { + std::vector new_fields; + for (const auto& field : tuple->fields) { + if (!IsOnDeviceNode(field.operator->())) { + new_fields.push_back(field); + } + } + CHECK_GT(new_fields.size(), 0U); + if (tuple->fields.size() == new_fields.size()) { + return new_fields.size() == 1 ? new_fields[0] : new_expr; + } else { + return new_fields.size() == 1 ? new_fields[0] + : TupleNode::make(new_fields); + } + } else { + return new_expr; + } } Map CollectDeviceInfo(const Expr& expr) { diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 9a77d2ffe856..98cf0f15446e 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -42,9 +42,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[2]), - func.body[2]) + return func def expected(): add = relay.add(x, y) @@ -58,6 +56,35 @@ def expected(): assert relay.ir_pass.alpha_equal(annotated_func, expected_func) +def test_annotate_expr(): + ctx1 = tvm.context(1) + ctx2 = tvm.context(2) + x = relay.var("x", shape=(3,)) + y = relay.var("y", shape=(3,)) + z = relay.var("z", shape=(3,)) + + def annotated(): + add = relay.add(x, y) + _add = relay.annotation.on_device(add, ctx1) + sub = relay.subtract(add, z) + _sub = relay.annotation.on_device(sub, ctx2) + expr = relay.Tuple([sub, _add, _sub]) + expr = relay.ir_pass.infer_type(expr) + expr = relay.ir_pass.rewrite_annotated_ops(expr, + ctx1.device_type) + return expr + + def expected(): + add = relay.add(x, y) + copy_add_sub = relay.device_copy(add, ctx1, ctx2) + sub = relay.subtract(copy_add_sub, z) + return sub + + annotated_expr = relay.ir_pass.infer_type(annotated()) + expected_expr = relay.ir_pass.infer_type(expected()) + assert relay.ir_pass.graph_equal(annotated_expr, expected_expr) + + def test_annotate_all(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) @@ -77,9 +104,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, ctx1.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[2]), - func.body[2]) + return func def expected(): add = relay.add(x, y) @@ -91,6 +116,7 @@ def expected(): expected_func = relay.ir_pass.infer_type(expected()) assert relay.ir_pass.alpha_equal(annotated_func, expected_func) + def test_annotate_none(): ctx1 = tvm.context(1) ctx2 = tvm.context(2) @@ -174,9 +200,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, tvm.context(3).device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[4]), - func.body[4]) + return func def expected(): conv2d_1 = relay.nn.conv2d( @@ -202,7 +226,7 @@ def expected(): kernel_size=(3, 3), padding=(1, 1)) - func = relay.Function([data1, weight, data2], conv2d_3) + func = relay.Function([data1, data2, weight], conv2d_3) return func def check_storage_and_device_types(): @@ -306,9 +330,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[2]), - func.body[2]) + return func def expected(): add = relay.add(x, y) @@ -358,9 +380,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, cpu_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[5]), - func.body[5]) + return func annotated_func = annotated() expected_func = get_func() @@ -386,9 +406,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[1]), - func.body[1]) + return func def expected(): add = relay.add(x, y) @@ -462,9 +480,7 @@ def annotated(): func = relay.ir_pass.infer_type(func) func = relay.ir_pass.rewrite_annotated_ops(func, dev_ctx.device_type) - func = relay.ir_pass.infer_type(func) - return relay.Function(relay.ir_pass.free_vars(func.body[3]), - func.body[3]) + return func def expected(): add = relay.add(a, b) @@ -506,6 +522,7 @@ def test_check_run(): if __name__ == "__main__": test_redundant_annotation() + test_annotate_expr() test_annotate_all() test_annotate_none() test_conv_network()