outline and inline lifted functions for external codegen
zhiics committed Mar 5, 2020
1 parent f63b249 commit b97bbc6
Showing 7 changed files with 118 additions and 60 deletions.
7 changes: 7 additions & 0 deletions src/relay/backend/build_module.cc
@@ -334,6 +334,13 @@ class RelayBuildModule : public runtime::ModuleNode {
// Fuse the operations if it is needed.
relay_module = transform::FuseOps()(relay_module);
relay_module = transform::InferType()(relay_module);
// Inline the functions that have been lifted to the module scope.
//
// TODO(@zhiics) Note that we need to be careful about subgraphs that contain
// global function calls. We should make sure that these callees are also
// marked as inline functions. However, this should be very unlikely for
// accelerators and vendor-provided libraries, so we don't handle it for now.
relay_module = transform::Inline()(relay_module);
CHECK(relay_module.defined());

return relay_module;
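
For reference, the effect of the new Inline step can be sketched from the Python side. The snippet below is a minimal sketch, assuming the relay.transform.Inline() binding is available in this TVM version and using the same function attributes that the tests in this commit set; it builds a module with one lifted external function and folds it back into its call site in main before codegen:

import tvm
from tvm import relay

# An external function as the partitioner would lift it to module scope.
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
ext = relay.Function([x, y], relay.add(x, y))
ext = ext.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
ext = ext.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
ext = ext.set_attribute("Compiler", tvm.tir.StringImm("ccompiler"))
ext = ext.set_attribute("ExternalSymbol", tvm.tir.StringImm("ccompiler_0"))

mod = tvm.IRModule()
glb = relay.GlobalVar("ccompiler_0")
mod[glb] = ext

a = relay.var("a", shape=(8, 8))
b = relay.var("b", shape=(8, 8))
mod["main"] = relay.Function([a, b], glb(a, b))

# Mirror the sequence added to RelayBuildModule::Optimize above:
# infer types, then fold the lifted function back into its call site.
mod = relay.transform.InferType()(mod)
mod = relay.transform.Inline()(mod)
print(mod)
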
7 changes: 7 additions & 0 deletions src/relay/backend/vm/compiler.cc
@@ -921,6 +921,13 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe
pass_seqs.push_back(transform::LambdaLift());
pass_seqs.push_back(transform::InlinePrimitives());

// Inline the functions that are lifted to the module scope. We perform this
// pass after all other optimization passes but before the memory allocation
// pass. This is because the memory allocation pass will insert `invoke_tvm_op`
// and we use these ops to invoke the symbols in the module generated by
// external codegen.
pass_seqs.push_back(transform::Inline());

// Manifest the allocations.
pass_seqs.push_back(transform::ManifestAlloc(this->target_host_));
// Compute away possibly introduced constant computation.
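
The ordering here matters: Inline has to run after lambda lifting and primitive inlining but before ManifestAlloc, so the memory planner sees the external calls at their call sites and wraps them in invoke_tvm_op. A rough Python-side sketch of that ordering, assuming the LambdaLift, Inline, and Sequential bindings are exposed in this version (InlinePrimitives and ManifestAlloc are applied internally by the VM compiler and are only indicated in comments; some_module is a placeholder):

from tvm import relay

# Approximate ordering used by VMCompiler::OptimizeModule above.
seq = relay.transform.Sequential([
    relay.transform.LambdaLift(),   # lift local closures to module scope
    relay.transform.Inline(),       # fold "Inline"-marked globals back in
    # ManifestAlloc(target_host) runs next inside the VM compiler and
    # inserts invoke_tvm_op around the now-inlined external calls.
])
# optimized_mod = seq(some_module)
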
1 change: 1 addition & 0 deletions src/relay/backend/vm/inline_primitives.cc
@@ -122,6 +122,7 @@ struct PrimitiveInliner : ExprMutator {
auto global = pair.first;
auto base_func = pair.second;
if (auto* n = base_func.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global
1 change: 1 addition & 0 deletions src/relay/backend/vm/lambda_lift.cc
@@ -189,6 +189,7 @@ class LambdaLifter : public ExprMutator {
auto glob_funcs = module_->functions;
for (auto pair : glob_funcs) {
if (auto* n = pair.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
auto func = GetRef<Function>(n);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
73 changes: 37 additions & 36 deletions src/relay/pass/partition_graph.cc
@@ -110,6 +110,8 @@ class AnnotationChecker : public ExprVisitor {
*/
class Partitioner : public ExprMutator {
public:
explicit Partitioner(const IRModule& module) : module_(module) {}

std::shared_ptr<Subgraph> GetSubgraph(const Expr node) {
for (auto candidate : this->subgraphs_) {
if (candidate->nodes.find(node) != candidate->nodes.end()) {
@@ -163,8 +165,10 @@ class Partitioner : public ExprMutator {

// Replace the begin annotation with an external call input variable.
auto compiler_attrs = call->attrs.as<CompilerAttrs>();
// The type of the created variable is the same as the compiler_begin
// node.
auto var = VarNode::make(compiler_attrs->compiler + "_input" + std::to_string(var_id_++),
input_expr->checked_type_);
call->checked_type_);

// Find the corresponding subgraph and add the argument.
auto subgraph = GetSubgraph(GetRef<Call>(call));
@@ -207,16 +211,24 @@ class Partitioner : public ExprMutator {
}

auto subgraph_func =
FunctionNode::make(params, input, call->args[0]->checked_type_, {}, Attrs());
FunctionNode::make(params, input, call->checked_type_, {}, Attrs());

Expr arg0 = call->args[0];
std::string name = compiler_attrs->compiler + "_" + std::to_string(subgraph->id);
subgraph_func =
FunctionSetAttr(subgraph_func, attr::kExternalSymbol, tir::StringImmNode::make(name));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kPrimitive, tvm::Integer(1));
subgraph_func = FunctionSetAttr(subgraph_func, attr::kCompiler,
tvm::tir::StringImmNode::make(compiler_attrs->compiler));
return CallNode::make(subgraph_func, args);
subgraph_func = FunctionSetAttr(subgraph_func, attr::kInline, tvm::Integer(1));
CHECK(!module_->ContainGlobalVar(name))
<< "Global function " << name << " already exists";
GlobalVar glob_func(name);
module_->Add(glob_func, subgraph_func);
// The return type of the call node is the same as the type of the
// compiler_end node.
auto ret = CallNode::make(glob_func, args);
ret->checked_type_ = call->checked_type_;
return std::move(ret);
}
}

@@ -330,50 +342,39 @@ class Partitioner : public ExprMutator {
}
}

IRModule Partition() {
auto glob_funcs = module_->functions;
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
func = FunctionNode::make(func->params,
VisitExpr(func->body),
func->ret_type,
func->type_params,
func->attrs);
module_->Update(pair.first, func);
}
}
return module_;
}

private:
int var_id_{0};
int subgraph_id_{0};
std::unordered_set<std::shared_ptr<Subgraph>> subgraphs_;
IRModule module_;
};

/*!
* \brief TODO(@zhiics, @comaniac) Combine parallel regions that belong to
* the same codegen backend. This reduces round trips between TVM and external
* backends. Likely we can borrow some ideas from operator fusion.
*
* For example, sg1 and sg2 should be combined if they belong to the same
* codegen tool in the following case.
*
*          op1
*         /   \
*       sg1   sg2
*
*           |
*          \|/
*
*          op1
*           |
*        sg1_sg2
*
* where the return type of the new subgraph sg1_sg2 is a tuple, and op1 has two
* inputs that are obtained from the tuple.
*/

Expr PartitionGraph(const Expr& expr) {
Partitioner part;
return part.Mutate(expr);
}

} // namespace partitioning

namespace transform {

Pass PartitionGraph() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> part_func =
[=](Function f, IRModule m, PassContext pc) {
return Downcast<Function>(partitioning::PartitionGraph(f));
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> part_func =
[=](IRModule m, PassContext pc) {
return partitioning::Partitioner(m).Partition();
};
auto partitioned = CreateFunctionPass(part_func, 0, "PartitionGraph", {});
auto partitioned = CreateModulePass(part_func, 0, "PartitionGraph", {});
return Sequential({partitioned, InferType()});
}

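
With this change PartitionGraph becomes a module-level pass: each annotated region is outlined into a named global function (e.g. ccompiler_0) carrying the Compiler/ExternalSymbol/Inline attributes, and the caller invokes it through a GlobalVar. A minimal user-level sketch, assuming the same ccompiler annotations used by the tests in this commit; InferType is run first so the partitioner can read checked types:

import tvm
from tvm import relay
from tvm.relay.annotation import compiler_begin, compiler_end

x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
# Mark an add as a region to be offloaded to the external "ccompiler" backend.
add = relay.add(compiler_begin(x, "ccompiler"), compiler_begin(y, "ccompiler"))
out = compiler_end(add, "ccompiler")

mod = tvm.IRModule()
mod["main"] = relay.Function([x, y], out)
mod = relay.transform.InferType()(mod)

# The annotated region is outlined into a global function such as
# @ccompiler_0, which main now calls through a GlobalVar.
mod = relay.transform.PartitionGraph()(mod)
print(mod)
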
3 changes: 3 additions & 0 deletions src/relay/pass/to_a_normal_form.cc
@@ -298,6 +298,9 @@ IRModule ToANormalForm(const IRModule& m) {
auto funcs = m->functions;
for (const auto& it : funcs) {
CHECK_EQ(FreeVars(it.second).size(), 0);
if (const auto* n = it.second.as<FunctionNode>()) {
if (!n->UseDefaultCompiler()) continue;
}
Expr ret =
TransformF([&](const Expr& e) {
return ToANormalFormAux(e);
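
The same compiler check also gates the VM's A-normal-form conversion, so bodies destined for external codegen keep their original form. A minimal sketch, assuming the relay.transform.ToANormalForm() binding is exposed in this version and reusing the externally-marked function style from the earlier sketches:

import tvm
from tvm import relay

x = relay.var("x", shape=(8, 8))
ext = relay.Function([x], relay.log(x))
ext = ext.set_attribute("Compiler", tvm.tir.StringImm("ccompiler"))
ext = ext.set_attribute("ExternalSymbol", tvm.tir.StringImm("ccompiler_0"))

mod = tvm.IRModule()
glb = relay.GlobalVar("ccompiler_0")
mod[glb] = ext
y = relay.var("y", shape=(8, 8))
mod["main"] = relay.Function([y], glb(y))
mod = relay.transform.InferType()(mod)

# Only functions using the default compiler (here, main) are converted to
# A-normal form; ccompiler_0 carries a non-default compiler and is skipped.
mod = relay.transform.ToANormalForm()(mod)
print(mod)
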
86 changes: 62 additions & 24 deletions tests/python/relay/test_pass_partition_graph.py
@@ -18,14 +18,12 @@
import os
import sys
import numpy as np
import pytest

import tvm
from tvm import te
import tvm.relay.testing
import tvm.relay.transform as transform
from tvm import relay
from tvm import runtime
from tvm.relay import transform
from tvm.contrib import util
from tvm.relay.annotation import compiler_begin, compiler_end
from tvm.relay.expr_functor import ExprMutator
@@ -189,7 +187,7 @@ def update_lib(lib):
return lib

def check_vm_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
with relay.build_config(opt_level=3):
exe = relay.vm.compile(mod, target=target, params=params)
code, lib = exe.save()
lib = update_lib(lib)
@@ -200,7 +198,7 @@ def check_vm_result():
tvm.testing.assert_allclose(out.asnumpy(), result, rtol=tol, atol=tol)

def check_graph_runtime_result():
with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
with relay.build_config(opt_level=3):
json, lib, param = relay.build(mod, target=target, params=params)
lib = update_lib(lib)
rt_mod = tvm.contrib.graph_runtime.create(json, lib, ctx)
@@ -297,6 +295,7 @@ def visit_call(self, call):

def test_extern_ccompiler_default_ops():
def expected():
mod = tvm.IRModule()
x = relay.var("x", shape=(8, 8))
y = relay.var("y", shape=(8, 8))
x0 = relay.var("x0", shape=(8, 8))
@@ -305,11 +304,14 @@ def expected():
# Function that uses C compiler
func = relay.Function([x0, y0], add)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler",
tvm.tir.StringImm("ccompiler"))
func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("ccompiler_0"))
add_call = relay.Call(func, [x, y])
glb_0 = relay.GlobalVar("ccompiler_0")
mod[glb_0] = func
add_call = relay.Call(glb_0, [x, y])
# Function that uses default compiler. Ops are fused in this function.
p0 = relay.var("p0", shape=(8, 8))
log = relay.log(p0)
@@ -320,7 +322,6 @@
tvm.tir.IntImm("int32", 1))
fused_call = relay.Call(fused_func, [add_call])
main = relay.Function([x, y], fused_call)
mod = tvm.IRModule()
mod["main"] = main
return mod

@@ -371,28 +372,65 @@ def test_extern_dnnl():
dtype = 'float32'
ishape = (1, 32, 14, 14)
w1shape = (32, 1, 3, 3)
data = relay.var('data', shape=(ishape), dtype=dtype)
weight1 = relay.var('weight1', shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)

f = relay.Function([data, weight1], out)

def expected():
data0 = relay.var("data", shape=(ishape), dtype=dtype)
input0 = relay.var("input0", shape=(w1shape), dtype=dtype)
input1 = relay.var("input1", shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data0,
input0,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
input1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)

func = relay.Function([data0, input0, input1], out)
func = func.set_attribute("Primitive", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Inline", tvm.tir.IntImm("int32", 1))
func = func.set_attribute("Compiler", tvm.tir.StringImm("dnnl"))
func = func.set_attribute("ExternalSymbol",
tvm.tir.StringImm("dnnl_0"))
glb_var = relay.GlobalVar("dnnl_0")
mod = tvm.IRModule()
mod[glb_var] = func

data = relay.var("data", shape=(ishape), dtype=dtype)
weight = relay.var("input", shape=(w1shape), dtype=dtype)
main_f = relay.Function([data, weight], glb_var(data, weight, weight))
mod["main"] = main_f

return mod

def get_func():
data = relay.var("data", shape=(ishape), dtype=dtype)
weight1 = relay.var("weight1", shape=(w1shape), dtype=dtype)
depthwise_conv2d_1 = relay.nn.conv2d(data,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
depthwise_conv2d_2 = relay.nn.conv2d(depthwise_conv2d_1,
weight1,
kernel_size=(3, 3),
padding=(1, 1),
groups=32)
out = relay.add(depthwise_conv2d_1, depthwise_conv2d_2)

return relay.Function([data, weight1], out)

mod = tvm.IRModule()
mod['main'] = WholeGraphAnnotator('dnnl').visit(f)
mod["main"] = WholeGraphAnnotator("dnnl").visit(get_func())
mod = transform.PartitionGraph()(mod)

assert relay.alpha_equal(mod, expected())

ref_mod = tvm.IRModule()
ref_mod['main'] = f
ref_mod["main"] = get_func()

i_data = np.random.uniform(0, 1, ishape).astype(dtype)
w1_data = np.random.uniform(0, 1, w1shape).astype(dtype)
