Skip to content

Commit

Permalink
[Relay][WIP] ConvertLayout pass.
Browse files Browse the repository at this point in the history
  • Loading branch information
anijain2305 committed Nov 14, 2019
1 parent 10b77ef commit 1fb5cc4
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 57 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ TVM_DLL Pass CanonicalizeOps();
*
* \return The pass.
*/
TVM_DLL Pass AlterOpLayout();
TVM_DLL Pass AlterOpLayout(const std::string& alter_map_attr_name = "FTVMAlterOpLayout");

/*!
* \brief Legalizes an expr with another expression.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Relay core operators."""
# operator defs
from .op import get, register, register_schedule, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
register_pattern, register_alter_op_layout, register_convert_op_layout, register_legalize, \
schedule_injective, Op, OpPattern, debug

# Operators
Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
return register(op_name, "FTVMAlterOpLayout", alter_layout, level)


def register_convert_op_layout(op_name, convert_layout=None, level=10):
"""Register alter op layout function for an op
Parameters
----------
op_name : str
The name of the operator
alter_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
The function for changing the layout or replacing the operator
level : int
The priority level
"""
return register(op_name, "FTVMConvertOpLayout", convert_layout, level)


def register_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for an op
Expand Down
69 changes: 67 additions & 2 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,18 +430,23 @@ def CombineParallelDense(min_num_branches=3):
return _transform.CombineParallelDense(min_num_branches)


def AlterOpLayout():
def AlterOpLayout(alter_map_attr_name="FTVMAlterOpLayout"):
"""Alternate the layouts of operators or replace primitive operators with
other expressions.
This pass can be used for computing convolution in custom layouts or
other general weight pre-transformation.
Parameters
----------
alter_map_attr_name : str
The Op's attr name which corresponds to the alter rule function.
Returns
-------
ret : tvm.relay.Pass
The registered pass that alters the layout of operators.
"""
return _transform.AlterOpLayout()
return _transform.AlterOpLayout(alter_map_attr_name)


def Legalize(legalize_map_attr_name="FTVMLegalize"):
Expand Down Expand Up @@ -979,3 +984,63 @@ def visit_var(self, var):
else:
return var
return ChangeBatchMutator().visit(func)

@function_pass(opt_level=1)
class ConvertLayout:
"""
Given a dest layout, this pass transforms the expr such that most of the ops input data layout
is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one at
the start and one at the end.
This pass is not a part of relay.build and is expected to be called between framework-relay
parser and relay.build call. This is very helpful for hardware backends that support/prefer only
type of data layout.
RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can just
overwrite conv2d alter op layout. Most of the other operators try to adapt to their input layout
using the InferCorrectLayout infrastructure.
Parameters
----------
dst_layout: str
The desired layout for the transformed expr.
Returns
-------
pass: FunctionPass
The pass.
"""

def __init__(self, dst_layout):
assert dst_layout == 'NCHW', \
'Currently, only NCHW layout conversion is supported.'

# Register convert layout function for conv2d op.
from tvm.relay.op import register_convert_op_layout
@register_convert_op_layout("nn.conv2d")
def alter_conv2d(attrs, inputs, tinfos):
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']
data, weight = inputs
if dst_layout == 'NCHW':
new_attrs = dict(attrs)
new_attrs['data_layout'] = dst_layout
new_attrs['kernel_layout'] = 'OIHW'

if data_layout == 'NHWC' and kernel_layout == 'HWIO':
# Convert (NHWC, HWIO) to (NCHW, OIHW)
return relay.nn.conv2d(data, weight, **new_attrs)
elif data_layout == 'NHWC' and kernel_layout == 'HWOI':
# Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
return relay.nn.conv2d(data, weight, **new_attrs)
return None
return None

def transform_function(self, func, mod, ctx):
cur_mod = relay.Module.from_expr(func)
cur_mod = CanonicalizeOps()(cur_mod)
cur_mod = AlterOpLayout("FTVMConvertOpLayout")(cur_mod)
cur_mod = FoldConstant()(cur_mod)
return cur_mod['main']
2 changes: 2 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,8 @@ class RelayBuildModule : public runtime::ModuleNode {
} else {
relay_module = seq(relay_module);
}
auto func1 = relay_module->Lookup("main");
LOG(INFO) << AsText(func1, false);

// Handle heterogeneous compilation.
transform::PassContext pass_ctx = PassContext::Current();
Expand Down
40 changes: 30 additions & 10 deletions src/relay/pass/alter_op_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,8 +185,9 @@ std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer(
// Call registered FTVMAlterOpLayout of an op
// Returns the altered expression
Call CallAlter(const Call& ref_call,
const std::vector<Expr>& new_args) {
static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout");
const std::vector<Expr>& new_args,
const std::string& alter_map_attr_name) {
auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>(alter_map_attr_name);
Op op = Downcast<Op>(ref_call->op);

Expr new_e;
Expand All @@ -213,9 +214,10 @@ Call CallAlter(const Call& ref_call,
return GetRef<Call>(new_call);
}

Expr AlterOpLayoutRewrite(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx) {
Expr Rewriter(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx,
const std::string& alter_map_attr_name) {
std::vector<LayoutAlternatedExpr> inputs;
std::vector<Expr> normal_new_args;
Array<Array<IndexExpr> > input_shapes;
Expand Down Expand Up @@ -289,7 +291,7 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
}

// new_op = alter(op)
Call new_call = CallAlter(ref_call, normal_new_args);
Call new_call = CallAlter(ref_call, normal_new_args, alter_map_attr_name);

// new_in2, new_out = op.infer(new_in)
if (new_call->op->IsInstance<OpNode>()) {
Expand Down Expand Up @@ -353,26 +355,44 @@ Expr AlterOpLayoutRewrite(const Call &ref_call,
}
}

Expr AlterOpLayoutRewrite(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx) {
return Rewriter(ref_call, new_args, ctx, "FTVMAlterOpLayout");
}

Expr ConvertOpLayoutRewrite(const Call &ref_call,
const Array<Expr> &new_args,
const NodeRef& ctx) {
return Rewriter(ref_call, new_args, ctx, "FTVMConvertOpLayout");
}

// Limiations:
// 1. the altered op should have the same number of arguments as the previous one
// 2. do not support nested tuple arguments
Expr AlterOpLayout(const Expr& expr) {
Expr AlterOpLayout(const Expr& expr, const std::string& alter_map_attr_name) {
TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>());
auto fcontext = [&](const Call& call) -> NodeRef{
return transformMemorizer;
};

return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
if (alter_map_attr_name == "FTVMAlterOpLayout") {
return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext);
} else if (alter_map_attr_name == "FTVMConvertOpLayout") {
return ForwardRewrite(expr, ConvertOpLayoutRewrite, fcontext);
}
LOG(FATAL) << "AlterOpLayout supports only 2 attr names - FTVMAlterOpLayout/FTVMConvertOpLayout";
return Expr();
}

} // namespace alter_op_layout

namespace transform {

Pass AlterOpLayout() {
Pass AlterOpLayout(const std::string& alter_map_attr_name) {
runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func =
[=](Function f, Module m, PassContext pc) {
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f));
return Downcast<Function>(relay::alter_op_layout::AlterOpLayout(f, alter_map_attr_name));
};
return CreateFunctionPass(pass_func, 3, "AlterOpLayout",
{ir::StringImm::make("InferType")});
Expand Down
89 changes: 46 additions & 43 deletions tests/python/frontend/tflite/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def convert_to_list(x):
x = [x]
return x

def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target='llvm',
def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1,
target='llvm -mcpu=cascadelake',
out_names=None):
""" Generic function to compile on relay and execute on tvm """
try:
Expand All @@ -72,6 +73,8 @@ def run_tvm_graph(tflite_model_buf, input_data, input_node, num_output=1, target
mod, params = relay.frontend.from_tflite(tflite_model,
shape_dict=shape_dict,
dtype_dict=dtype_dict)
mod = relay.transform.ConvertLayout('NCHW')(mod)
import tvm
with relay.build_config(opt_level=3):
graph, lib, params = relay.build(mod, target, params=params)

Expand Down Expand Up @@ -1101,59 +1104,59 @@ def test_forward_ssd_mobilenet_v1():
# Main
# ----
if __name__ == '__main__':
# BatchToSpaceND
test_forward_batch_to_space_nd()
# # BatchToSpaceND
# test_forward_batch_to_space_nd()

# SpaceToBatchND
test_forward_space_to_batch_nd()
# # SpaceToBatchND
# test_forward_space_to_batch_nd()

# Split
test_forward_split()
# # Split
# test_forward_split()

# Transpose
test_forward_transpose()
# # Transpose
# test_forward_transpose()

# Cast
test_forward_cast()
# # Cast
# test_forward_cast()

# Tile
test_forward_tile()
# # Tile
# test_forward_tile()

# Transforms
test_forward_concatenation()
test_forward_pad()
test_forward_pack()
test_forward_reshape()
test_all_resize()
test_forward_squeeze()
# # Transforms
# test_forward_concatenation()
# test_forward_pad()
# test_forward_pack()
# test_forward_reshape()
# test_all_resize()
# test_forward_squeeze()

# NN
test_forward_convolution()
test_forward_logistic()
test_forward_pooling()
test_forward_softmax()
test_forward_tanh()
test_forward_relu()
test_forward_prelu()
test_forward_fully_connected()
# # NN
# test_forward_convolution()
# test_forward_logistic()
# test_forward_pooling()
# test_forward_softmax()
# test_forward_tanh()
# test_forward_relu()
# test_forward_prelu()
# test_forward_fully_connected()

# Elemwise
test_all_elemwise()
# # Elemwise
# test_all_elemwise()

# Zeros Like
test_forward_zeros_like()
# # Zeros Like
# test_forward_zeros_like()

# Reduce
test_all_reduce()
# # Reduce
# test_all_reduce()

# End to End
# # End to End
test_forward_mobilenet_v1()
test_forward_mobilenet_v2()
# test_forward_mobilenet_v2()
test_forward_inception_v3_net()
test_forward_inception_v4_net()
test_forward_ssd_mobilenet_v1()
# test_forward_inception_v4_net()
# test_forward_ssd_mobilenet_v1()

# End to End quantized
test_forward_qnn_inception_v1_net()
test_forward_qnn_mobilenet_v1_net()
test_forward_qnn_mobilenet_v2_net()
# # End to End quantized
# test_forward_qnn_inception_v1_net()
# test_forward_qnn_mobilenet_v1_net()
# test_forward_qnn_mobilenet_v2_net()

0 comments on commit 1fb5cc4

Please sign in to comment.