Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] ConvertLayout pass. #4335

Merged
merged 1 commit into from
Dec 26, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions include/tvm/relay/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <tvm/build_module.h>
#include <tvm/relay/type.h>
#include <tvm/relay/expr.h>
#include <string>

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -132,6 +133,22 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc<
const Array<Expr>& args,
const Array<Tensor>& tinfos)>;

/*!
* \brief Convert the layout of operators or replace the
* operator with other expressions. This function will be invoked
* in ConvertLayout pass.
* \param attrs The attribute of the original node.
* \param inputs The input symbols of the original node.
* \param tinfos An array of placeholders, use for getting the inferred shape
* and dtype of the inputs.
* \param desired_layout The desired layout.
* \return new_expr The modified expression.
*/
using FTVMConvertOpLayout = runtime::TypedPackedFunc<
Expr(const Attrs& attrs,
const Array<Expr>& args,
const Array<Tensor>& tinfos,
const std::string& desired_layout)>;
/*!
* \brief Legalizes an expression with another expression. This function will be
* invoked in Legalize pass. It is a target-dependent pass.
Expand Down
20 changes: 20 additions & 0 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,26 @@ TVM_DLL Pass CanonicalizeOps();
*/
TVM_DLL Pass AlterOpLayout();

/*!
* \brief Given a dest layout, this pass transforms the expr such that most of the ops input data
* layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one
* at the start and one at the end.
*
* This pass is not a part of relay.build and is expected to be called between framework-relay
* parser and relay.build call. This is very helpful for hardware backends that support/prefer only
* type of data layout.
*
* RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
*
* This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new
* layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout
* using the InferCorrectLayout infrastructure.
*
* \param desired_layout The desired layout.
* \return The pass.
*/
TVM_DLL Pass ConvertLayout(const std::string& desired_layout);

/*!
* \brief Legalizes an expr with another expression.
* \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function.
Expand Down
41 changes: 41 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,47 @@ def legalize_conv2d(attrs, inputs, types):
"""
return topi.nn.conv2d_legalize(attrs, inputs, types)


@reg.register_convert_op_layout("nn.conv2d")
def convert_conv2d(attrs, inputs, tinfos, desired_layout):
"""Convert Layout pass registration for conv2d op.

Parameters
----------
attrs : tvm.attrs.Attrs
Attributes of current convolution
inputs : list of tvm.relay.Expr
The args of the Relay expr to be legalized
tinfos : list of types
List of input and output types
desired_layout : str
The desired layout

Returns
-------
result : tvm.relay.Expr
The transformed expr
"""

from tvm import relay
data_layout = attrs['data_layout']
kernel_layout = attrs['kernel_layout']
data, weight = inputs
assert desired_layout == 'NCHW', \
"Currently only transformation to NCHW layout is supported."
if desired_layout == 'NCHW':
new_attrs = dict(attrs)
new_attrs['data_layout'] = desired_layout
new_attrs['kernel_layout'] = 'OIHW'

if data_layout == 'NHWC' and kernel_layout == 'HWIO':
# Convert (NHWC, HWIO) to (NCHW, OIHW)
return relay.nn.conv2d(data, weight, **new_attrs)
if data_layout == 'NHWC' and kernel_layout == 'HWOI':
# Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d.
return relay.nn.conv2d(data, weight, **new_attrs)
return None

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


Expand Down
17 changes: 17 additions & 0 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10):
return register(op_name, "FTVMAlterOpLayout", alter_layout, level)


def register_convert_op_layout(op_name, convert_layout=None, level=10):
"""Register convert op layout function for an op
Parameters
----------
op_name : str
The name of the operator
convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr
The function for changing the layout or replacing the operator
level : int
The priority level
"""
return register(op_name, "FTVMConvertOpLayout", convert_layout, level)


def register_legalize(op_name, legal_op=None, level=10):
"""Register legal transformation function for an op
Expand Down
28 changes: 28 additions & 0 deletions python/tvm/relay/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,34 @@ def AlterOpLayout():
return _transform.AlterOpLayout()


def ConvertLayout(desired_layout):
""" Given a dest layout, this pass transforms the expr such that most of the ops input data
layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms,
one at the start and one at the end.
This pass is not a part of relay.build and is expected to be called between framework-relay
parser and relay.build call. This is very helpful for hardware backends that support/prefer only
type of data layout.
RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009
This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define
new layouts for conv2d ops for now. Most of the other operators try to adapt to their input
layout using the InferCorrectLayout infrastructure.
Parameters
----------
desired_layout : str
The desired layout for the transformed expr.
Returns
-------
pass: FunctionPass
The pass.
"""
return _transform.ConvertLayout(desired_layout)


def Legalize(legalize_map_attr_name="FTVMLegalize"):
"""Legalizes an expression with another expression.
This pass can be used to replace an expr with another expr for target
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/annotation/annotation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
#include <tvm/relay/op_attr_types.h>
#include <topi/elemwise.h>

#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"
#include "../type_relations.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/device_copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#include <tvm/relay/op_attr_types.h>

#include "type_relations.h"
#include "../pass/alter_op_layout.h"
#include "../pass/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/memory/memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
#include <tvm/relay/attrs/memory.h>

#include "../op_common.h"
#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"
#include "../type_relations.h"

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/bitserial.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include <tvm/relay/attrs/bitserial.h>
#include <tvm/relay/op.h>

#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"
#include "../op_common.h"
#include "convolution.h"

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <topi/nn/flatten.h>
#include <vector>
#include "../type_relations.h"
#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"
#include "../op_common.h"
#include "nn.h"

Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h>
#include <topi/nn/pooling.h>
#include <vector>
#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/nn/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/relay/attrs/nn.h>
#include <vector>

#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/op_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include <string>
#include <unordered_map>
#include "type_relations.h"
#include "../pass/alter_op_layout.h"
#include "../pass/infer_layout_util.h"

namespace tvm {
namespace relay {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#include <vector>
#include "../op_common.h"
#include "../../../arithmetic/compute_expr.h"
#include "../../pass/alter_op_layout.h"
#include "../../pass/infer_layout_util.h"
#include "../../pass/pattern_util.h"
#include "transform.h"

Expand Down
Loading