[Relay][WIP] ConvertLayout pass.
anijain2305 committed Nov 14, 2019
1 parent 5d66e7a commit a066107
Showing 9 changed files with 459 additions and 22 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
@@ -530,7 +530,7 @@ TVM_DLL Pass CanonicalizeOps();
*
* \return The pass.
*/
TVM_DLL Pass AlterOpLayout();
TVM_DLL Pass AlterOpLayout(const std::string& alter_map_attr_name = "FTVMAlterOpLayout");

/*!
* \brief Legalizes an expr with another expression.
2 changes: 1 addition & 1 deletion python/tvm/relay/op/__init__.py
@@ -18,7 +18,7 @@
"""Relay core operators."""
# operator defs
from .op import get, register, register_schedule, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
register_pattern, register_alter_op_layout, register_convert_op_layout, register_legalize, \
schedule_injective, Op, OpPattern, debug

# Operators
17 changes: 17 additions & 0 deletions python/tvm/relay/op/op.py
@@ -186,6 +186,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
return register(op_name, "FTVMAlterOpLayout", alter_layout, level)


def register_convert_op_layout(op_name, convert_layout=None, level=10):
    """Register convert op layout function for an op

    Parameters
    ----------
    op_name : str
        The name of the operator

    convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
        The function for converting the layout of the operator or replacing it

    level : int
        The priority level
    """
    return register(op_name, "FTVMConvertOpLayout", convert_layout, level)
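For context, a minimal sketch of how this hook can be used (a hypothetical standalone example mirroring the conv2d convert function registered by ConvertLayout in transform.py below; returning None leaves the op unchanged):

from tvm import relay
from tvm.relay.op import register_convert_op_layout

@register_convert_op_layout("nn.conv2d")
def convert_conv2d_to_nchw(attrs, inputs, tinfos):
    # attrs/inputs come from the original NHWC call; tinfos carry the input types.
    data, weight = inputs
    if attrs['data_layout'] == 'NHWC' and attrs['kernel_layout'] == 'HWIO':
        new_attrs = dict(attrs)
        new_attrs['data_layout'] = 'NCHW'
        new_attrs['kernel_layout'] = 'OIHW'
        return relay.nn.conv2d(data, weight, **new_attrs)
    return None  # keep the original op for any other layout combination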


def register_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for an op
69 changes: 67 additions & 2 deletions python/tvm/relay/transform.py
@@ -430,18 +430,23 @@ def CombineParallelDense(min_num_branches=3):
return _transform.CombineParallelDense(min_num_branches)


def AlterOpLayout():
def AlterOpLayout(alter_map_attr_name="FTVMAlterOpLayout"):
    """Alternate the layouts of operators or replace primitive operators with
    other expressions.
    This pass can be used for computing convolution in custom layouts or
    other general weight pre-transformation.

    Parameters
    ----------
    alter_map_attr_name : str
        The Op's attr name which corresponds to the alter rule function.

    Returns
    -------
    ret : tvm.relay.Pass
        The registered pass that alters the layout of operators.
    """
    return _transform.AlterOpLayout()
    return _transform.AlterOpLayout(alter_map_attr_name)
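A short usage sketch of the new parameter (assuming a module `mod` whose ops have the relevant functions registered; the ConvertLayout pass further down drives the same machinery this way):

# Default: rewrite using the FTVMAlterOpLayout registry (behaviour unchanged).
mod = relay.transform.AlterOpLayout()(mod)
# Reuse the pass against the convert-layout registry added in this commit.
mod = relay.transform.AlterOpLayout("FTVMConvertOpLayout")(mod)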


def Legalize(legalize_map_attr_name="FTVMLegalize"):
@@ -979,3 +984,63 @@ def visit_var(self, var):
else:
return var
return ChangeBatchMutator().visit(func)

@function_pass(opt_level=1)
class ConvertLayout:
    """
    Given a dest layout, this pass transforms the expr such that the input data layout of most
    ops is changed to the dest layout. In the ideal situation, there are only two layout
    transforms, one at the start and one at the end.

    This pass is not a part of relay.build and is expected to be called between the
    framework-Relay parser and the relay.build call. This is very helpful for hardware backends
    that support/prefer only one type of data layout.

    RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009

    This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can just
    overwrite the conv2d alter op layout. Most of the other operators try to adapt to their
    input layout using the InferCorrectLayout infrastructure.

    Parameters
    ----------
    dst_layout: str
        The desired layout for the transformed expr.

    Returns
    -------
    pass: FunctionPass
        The pass.
    """

    def __init__(self, dst_layout):
        assert dst_layout == 'NCHW', \
            'Currently, only NCHW layout conversion is supported.'

        # Register the convert layout function for the conv2d op.
        from tvm.relay.op import register_convert_op_layout

        @register_convert_op_layout("nn.conv2d")
        def alter_conv2d(attrs, inputs, tinfos):
            data_layout = attrs['data_layout']
            kernel_layout = attrs['kernel_layout']
            data, weight = inputs
            if dst_layout == 'NCHW':
                new_attrs = dict(attrs)
                new_attrs['data_layout'] = dst_layout
                new_attrs['kernel_layout'] = 'OIHW'

                if data_layout == 'NHWC' and kernel_layout == 'HWIO':
                    # Convert (NHWC, HWIO) to (NCHW, OIHW)
                    return relay.nn.conv2d(data, weight, **new_attrs)
                elif data_layout == 'NHWC' and kernel_layout == 'HWOI':
                    # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
                    return relay.nn.conv2d(data, weight, **new_attrs)
                return None
            return None

    def transform_function(self, func, mod, ctx):
        cur_mod = relay.Module.from_expr(func)
        cur_mod = CanonicalizeOps()(cur_mod)
        cur_mod = AlterOpLayout("FTVMConvertOpLayout")(cur_mod)
        cur_mod = FoldConstant()(cur_mod)
        return cur_mod['main']
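As a usage sketch (hypothetical shapes), the pass is meant to be applied to an NHWC graph after the frontend parser and before relay.build:

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 56, 56, 64))
weight = relay.var("weight", shape=(3, 3, 64, 64))
out = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3),
                      data_layout="NHWC", kernel_layout="HWIO")
mod = relay.Module.from_expr(relay.Function([data, weight], out))

# Convert to NCHW; conv2d now runs in NCHW, with layout_transform ops
# inserted at the graph boundaries.
mod = relay.transform.ConvertLayout('NCHW')(mod)
print(mod['main'])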
2 changes: 2 additions & 0 deletions src/relay/backend/build_module.cc
@@ -345,6 +345,8 @@ class RelayBuildModule : public runtime::ModuleNode {
} else {
relay_module = seq(relay_module);
}
auto func1 = relay_module->Lookup("main");
LOG(INFO) << AsText(func1, false);

// Handle heterogeneous compilation.
transform::PassContext pass_ctx = PassContext::Current();
15 changes: 8 additions & 7 deletions src/relay/ir/op.cc
@@ -109,13 +109,14 @@ void OpRegistry::UpdateAttr(const std::string& key,
if (op_map->data_.size() <= index) {
op_map->data_.resize(index + 1, std::make_pair(TVMRetValue(), 0));
}
std::pair<TVMRetValue, int>& p = op_map->data_[index];
CHECK(p.second != plevel)
<< "Attribute " << key << " of operator " << this->name
<< " is already registered with same plevel=" << plevel;
if (p.second < plevel) {
op_map->data_[index] = std::make_pair(value, plevel);
}
op_map->data_[index] = std::make_pair(value, plevel);
// std::pair<TVMRetValue, int>& p = op_map->data_[index];
// CHECK(p.second != plevel)
// << "Attribute " << key << " of operator " << this->name
// << " is already registered with same plevel=" << plevel;
// if (p.second < plevel) {
// op_map->data_[index] = std::make_pair(value, plevel);
// }
}

// Frontend APIs
40 changes: 30 additions & 10 deletions src/relay/pass/alter_op_layout.cc
@@ -185,8 +185,9 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
// Call registered FTVMAlterOpLayout of an op
// Returns the altered expression
Call CallAlter(const Call& ref_call,
const std::vector<Expr>& new_args) {
static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
const std::vector<Expr>& new_args,
const std::string& alter_map_attr_name) {
auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>(alter_map_attr_name);
Op op = Downcast<Op>(ref_call->op);

Expr new_e;
@@ -213,9 +214,10 @@ Call CallAlter(const Call& ref_call,
return GetRef<Call>(new_call);
}

Expr AlterOpLayoutRewrite(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx) {
Expr Rewriter(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx,
const std::string& alter_map_attr_name) {
std::vector<LayoutAlternatedExpr> inputs;
std::vector<Expr> normal_new_args;
Array<Array<IndexExpr> > input_shapes;
@@ -289,7 +291,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
}

// new_op = alter(op)
Call new_call = CallAlter(ref_call, normal_new_args);
Call new_call = CallAlter(ref_call, normal_new_args, alter_map_attr_name);

// new_in2, new_out = op.infer(new_in)
if (new_call->op->IsInstance<OpNode>()) {
@@ -353,26 +355,44 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
}
}

Expr AlterOpLayoutRewrite(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx) {
return Rewriter(ref_call, new_args, ctx, "FTVMAlterOpLayout");
}

Expr ConvertOpLayoutRewrite(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx) {
return Rewriter(ref_call, new_args, ctx, "FTVMConvertOpLayout");
}

// Limitations:
// 1. the altered op should have the same number of arguments as the previous one
// 2. do not support nested tuple arguments
Expr AlterOpLayout(const Expr& expr) {
Expr AlterOpLayout(const Expr& expr, const std::string& alter_map_attr_name) {
TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> NodeRef{
return transformMemorizer;
};

return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
if (alter_map_attr_name == "FTVMAlterOpLayout") {
return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
} else if (alter_map_attr_name == "FTVMConvertOpLayout") {
return ForwardRewrite(expr, ConvertOpLayoutRewrite, fcontext);
}
LOG(FATAL) << "AlterOpLayout supports only 2 attr names - FTVMAlterOpLayout/FTVMConvertOpLayout";
return Expr();
}

} // namespace alter_op_layout

namespace transform {

Pass AlterOpLayout() {
Pass AlterOpLayout(const std::string& alter_map_attr_name) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f, alter_map_attr_name));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
{ir::StringImm::make("InferType")});
5 changes: 4 additions & 1 deletion tests/python/frontend/tflite/test_forward.py
@@ -72,6 +72,9 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
mod, params = relay.frontend.from_tflite(tflite_model,
shape_dict=shape_dict,
dtype_dict=dtype_dict)
# Convert to NCHW layout.
mod = relay.transform.ConvertLayout('NCHW')(mod)

with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)

@@ -151,7 +154,7 @@ def compare_tflite_with_tvm(in_data, in_name, input_tensors,
tflite_model_buffer = converter.convert()
tflite_output = run_tflite_graph(tflite_model_buffer, in_data)

for device in ["llvm"]:
for device in ["llvm", "cuda"]:
ctx = tvm.context(device, 0)
if not ctx.exist:
print("Skip because %s is not enabled" % device)