diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 61ec28179eee..41833c4d4aff 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode { // Fuse the operations if it is needed. relay_module = transform::FuseOps()(relay_module); relay_module = transform::InferType()(relay_module); + // Inline the functions that have been lifted by the module scope. + // + // TODO(@zhiics) Note that we need to be careful about the subgraphs with + // global function calls. We should make sure that these callees are also + // inline functions. However, this should be very unlikely for accelerators + // and vendor-provided libraries. So we don't handle for now. + relay_module = transform::Inline()(relay_module); CHECK(relay_module.defined()); return relay_module; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 2129b64a8b44..fc52a8e939d4 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::LambdaLift()); pass_seqs.push_back(transform::InlinePrimitives()); + // Inline the functions that are lifted to the module scope. We perform this + // pass after all other optimization passes but before the memory allocation + // pass. This is because memory allocation pass will insert `invoke_tvm_op` + // and we use these ops to invoke the symbols in the module generated by + // external codegen. + pass_seqs.push_back(transform::Inline()); + // Manifest the allocations. pass_seqs.push_back(transform::ManifestAlloc(this->target_host_)); // Compute away possibly introduced constant computation. diff --git a/src/relay/backend/vm/inline_primitives.cc b/src/relay/backend/vm/inline_primitives.cc index 1d6ba4a46c61..25a9bcd38416 100644 --- a/src/relay/backend/vm/inline_primitives.cc +++ b/src/relay/backend/vm/inline_primitives.cc @@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator { auto global = pair.first; auto base_func = pair.second; if (auto* n = base_func.as()) { + if (!n->UseDefaultCompiler()) continue; auto func = GetRef(n); DLOG(INFO) << "Before inlining primitives: " << global diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 1cf671a1999b..5cf66c5807b0 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator { auto glob_funcs = module_->functions; for (auto pair : glob_funcs) { if (auto* n = pair.second.as()) { + if (!n->UseDefaultCompiler()) continue; auto func = GetRef(n); func = FunctionNode::make(func->params, VisitExpr(func->body), diff --git a/src/relay/pass/partition_graph.cc b/src/relay/pass/partition_graph.cc index e58bf61aa496..5e3aaf74b84d 100644 --- a/src/relay/pass/partition_graph.cc +++ b/src/relay/pass/partition_graph.cc @@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor { */ class Partitioner : public ExprMutator { public: + explicit Partitioner(const IRModule& module) : module_(module) {} + std::shared_ptr GetSubgraph(const Expr node) { for (auto candidate : this->subgraphs_) { if (candidate->nodes.find(node) != candidate->nodes.end()) { @@ -163,8 +165,10 @@ class Partitioner : public ExprMutator { // Replace the begin annotation with an external call input variable. auto compiler_attrs = call->attrs.as(); + // The type of the created variable is the same as the compiler_begin + // node. auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++), - input_expr->checked_type_); + call->checked_type_); // Find the corresponding subgraph and add the argument. auto subgraph = GetSubgraph(GetRef(call)); @@ -207,16 +211,24 @@ class Partitioner : public ExprMutator { } auto subgraph_func = - FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs()); + FunctionNode::make(params, input, call->checked_type_, {}, Attrs()); - Expr arg0 = call->args[0]; std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id); subgraph_func = FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1)); subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler, tvm::tir::StringImmNode::make(compiler_attrs->compiler)); - return CallNode::make(subgraph_func, args); + subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1)); + CHECK(!module_->ContainGlobalVar(name)) + << "Global function " << name << " is already existed"; + GlobalVar glob_func(name); + module_->Add(glob_func, subgraph_func); + // The return type of callnode is the same as the type of the + // compiler_end node. + auto ret = CallNode::make(glob_func, args); + ret->checked_type_ = call->checked_type_; + return std::move(ret); } } @@ -330,50 +342,39 @@ class Partitioner : public ExprMutator { } } + IRModule Partition() { + auto glob_funcs = module_->functions; + for (const auto& pair : glob_funcs) { + if (auto* fn = pair.second.as()) { + auto func = GetRef(fn); + func = FunctionNode::make(func->params, + VisitExpr(func->body), + func->ret_type, + func->type_params, + func->attrs); + module_->Update(pair.first, func); + } + } + return module_; + } + private: int var_id_{0}; int subgraph_id_{0}; std::unordered_set> subgraphs_; + IRModule module_; }; -/*! - * \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to - * the same codegen backend. This reduces rounds trips between TVM and external - * backends. Likely we can borrow some ideas from operator fusion. - * - * For example, sg1 and sg2 should be combined if they belong to the same - * codegen tool in the following case. - * - * op1 - * / \ - * sg1 sg2 - * - * | - * \|/ - * - * op1 - * | - * sg1_sg2 - * - * where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two - * inputs that obtained from the tuple. - */ - -Expr PartitionGraph(const Expr& expr) { - Partitioner part; - return part.Mutate(expr); -} - } // namespace partitioning namespace transform { Pass PartitionGraph() { - runtime::TypedPackedFunc part_func = - [=](Function f, IRModule m, PassContext pc) { - return Downcast(partitioning::PartitionGraph(f)); + runtime::TypedPackedFunc part_func = + [=](IRModule m, PassContext pc) { + return partitioning::Partitioner(m).Partition(); }; - auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {}); + auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {}); return Sequential({partitioned, InferType()}); } diff --git a/src/relay/pass/to_a_normal_form.cc b/src/relay/pass/to_a_normal_form.cc index 9322e490d6a3..c75afd12e9dc 100644 --- a/src/relay/pass/to_a_normal_form.cc +++ b/src/relay/pass/to_a_normal_form.cc @@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) { auto funcs = m->functions; for (const auto& it : funcs) { CHECK_EQ(FreeVars(it.second).size(), 0); + if (const auto* n = it.second.as()) { + if (!n->UseDefaultCompiler()) continue; + } Expr ret = TransformF([&](const Expr& e) { return ToANormalFormAux(e); diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index 9c3228f4ff48..9537c9b6a3e6 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -18,14 +18,12 @@ import os import sys import numpy as np -import pytest import tvm -from tvm import te import tvm.relay.testing -import tvm.relay.transform as transform from tvm import relay from tvm import runtime +from tvm.relay import transform from tvm.contrib import util from tvm.relay.annotation import compiler_begin, compiler_end from tvm.relay.expr_functor import ExprMutator @@ -189,7 +187,7 @@ def update_lib(lib): return lib def check_vm_result(): - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with relay.build_config(opt_level=3): exe = relay.vm.compile(mod, target=target, params=params) code, lib = exe.save() lib = update_lib(lib) @@ -200,7 +198,7 @@ def check_vm_result(): tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol) def check_graph_runtime_result(): - with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]): + with relay.build_config(opt_level=3): json, lib, param = relay.build(mod, target=target, params=params) lib = update_lib(lib) rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx) @@ -297,6 +295,7 @@ def visit_call(self, call): def test_extern_ccompiler_default_ops(): def expected(): + mod = tvm.IRModule() x = relay.var("x", shape=(8, 8)) y = relay.var("y", shape=(8, 8)) x0 = relay.var("x0", shape=(8, 8)) @@ -305,11 +304,14 @@ def expected(): # Function that uses C compiler func = relay.Function([x0, y0], add) func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) func = func.set_attribute("Compiler", tvm.tir.StringImm("ccompiler")) func = func.set_attribute("ExternalSymbol", tvm.tir.StringImm("ccompiler_0")) - add_call = relay.Call(func, [x, y]) + glb_0 = relay.GlobalVar("ccompiler_0") + mod[glb_0] = func + add_call = relay.Call(glb_0, [x, y]) # Function that uses default compiler. Ops are fused in this function. p0 = relay.var("p0", shape=(8, 8)) log = relay.log(p0) @@ -320,7 +322,6 @@ def expected(): tvm.tir.IntImm("int32", 1)) fused_call = relay.Call(fused_func, [add_call]) main = relay.Function([x, y], fused_call) - mod = tvm.IRModule() mod["main"] = main return mod @@ -371,28 +372,65 @@ def test_extern_dnnl(): dtype = 'float32' ishape = (1, 32, 14, 14) w1shape = (32, 1, 3, 3) - data = relay.var('data', shape=(ishape), dtype=dtype) - weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype) - depthwise_conv2d_1 = relay.nn.conv2d(data, - weight1, - kernel_size=(3, 3), - padding=(1, 1), - groups=32) - depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, - weight1, - kernel_size=(3, 3), - padding=(1, 1), - groups=32) - out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) - - f = relay.Function([data, weight1], out) + + def expected(): + data0 = relay.var("data", shape=(ishape), dtype=dtype) + input0 = relay.var("input0", shape=(w1shape), dtype=dtype) + input1 = relay.var("input1", shape=(w1shape), dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data0, + input0, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + input1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + func = relay.Function([data0, input0, input1], out) + func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1)) + func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1)) + func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl")) + func = func.set_attribute("ExternalSymbol", + tvm.tir.StringImm("dnnl_0")) + glb_var = relay.GlobalVar("dnnl_0") + mod = tvm.IRModule() + mod[glb_var] = func + + data = relay.var("data", shape=(ishape), dtype=dtype) + weight = relay.var("input", shape=(w1shape), dtype=dtype) + main_f = relay.Function([data, weight], glb_var(data, weight, weight)) + mod["main"] = main_f + + return mod + + def get_func(): + data = relay.var("data", shape=(ishape), dtype=dtype) + weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype) + depthwise_conv2d_1 = relay.nn.conv2d(data, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1, + weight1, + kernel_size=(3, 3), + padding=(1, 1), + groups=32) + out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2) + + return relay.Function([data, weight1], out) mod = tvm.IRModule() - mod['main'] = WholeGraphAnnotator('dnnl').visit(f) + mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func()) mod = transform.PartitionGraph()(mod) + assert relay.alpha_equal(mod, expected()) + ref_mod = tvm.IRModule() - ref_mod['main'] = f + ref_mod["main"] = get_func() i_data = np.random.uniform(0, 1, ishape).astype(dtype) w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)