[AlterLayout] NCHWc upsampling, fix depthwise conv #2806

Merged · 2 commits · Mar 20, 2019
25 changes: 25 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -345,3 +345,28 @@ def schedule_contrib_conv2d_NCHWc(attrs, outs, target):

reg.register_pattern("nn.contrib_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of depthwise conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
data_layout, out_layout, out_dtype)
return [out]

@reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
"""Schedule definition of contrib_conv2d_NCHWc"""
with target:
return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)

reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)
64 changes: 64 additions & 0 deletions python/tvm/relay/op/nn/nn.py
@@ -927,6 +927,70 @@ def contrib_conv2d_nchwc(data,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)

def contrib_depthwise_conv2d_nchwc(data,
kernel,
strides=(1, 1),
padding=(0, 0),
dilation=(1, 1),
groups=1,
channels=None,
kernel_size=None,
data_layout="NCHW8c",
kernel_layout="OIHW",
out_layout="",
out_dtype=""):
r"""Variant of 2D depthwise convolution.
This operator takes the weight as the depthwise convolution kernel
and depthwise convolves it with data to produce an output, following a specialized
NCHWc data layout.
Parameters
----------
data : tvm.relay.Expr
The input data to the operator.
kernel : tvm.relay.Expr
The kernel expressions.
strides : tuple of int, optional
The strides of convolution.
padding : tuple of int, optional
The padding of convolution on both sides of inputs before convolution.
dilation : tuple of int, optional
Specifies the dilation rate to be used for dilated convolution.
groups : int, optional
Number of groups for grouped convolution.
channels : int, optional
Number of output channels of this convolution.
kernel_size : tuple of int, optional
The spatial dimensions of the convolution kernel.
data_layout : str, optional
Layout of the input.
kernel_layout : str, optional
Layout of the weight.
out_layout : str, optional
Layout of the output. By default, out_layout is the same as data_layout.
out_dtype : str, optional
Specifies the output data type for mixed precision conv2d.
Returns
-------
result : tvm.relay.Expr
The computed result.
"""
return _make.contrib_depthwise_conv2d_NCHWc(data, kernel, strides, padding, dilation,
groups, channels, kernel_size, data_layout,
kernel_layout, out_layout, out_dtype)

def contrib_conv2d_winograd_weight_transform(weight,
tile_size):
52 changes: 52 additions & 0 deletions src/relay/op/nn/convolution.cc
@@ -582,5 +582,57 @@ RELAY_REGISTER_OP("nn.contrib_conv2d_NCHWc")
Conv2DInferCorrectLayout<Conv2DAttrs>);


// Positional relay function to create depthwise conv2d NCHWc operator
// used by frontend FFI.
Expr MakeDepthwiseConv2DNCHWc(Expr data,
Expr kernel,
Array<IndexExpr> strides,
Array<IndexExpr> padding,
Array<IndexExpr> dilation,
int groups,
IndexExpr channels,
Array<IndexExpr> kernel_size,
std::string data_layout,
std::string kernel_layout,
std::string out_layout,
DataType out_dtype) {
auto attrs = make_node<Conv2DAttrs>();
attrs->strides = std::move(strides);
attrs->padding = std::move(padding);
attrs->dilation = std::move(dilation);
attrs->groups = groups;
attrs->channels = channels;
attrs->kernel_size = std::move(kernel_size);
attrs->data_layout = std::move(data_layout);
attrs->kernel_layout = std::move(kernel_layout);
attrs->out_layout = std::move(out_layout);
attrs->out_dtype = std::move(out_dtype);
static const Op& op = Op::Get("nn.contrib_depthwise_conv2d_NCHWc");
return CallNode::make(op, {data, kernel}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op.nn._make.contrib_depthwise_conv2d_NCHWc")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 12>(MakeDepthwiseConv2DNCHWc, args, rv);
});


RELAY_REGISTER_OP("nn.contrib_depthwise_conv2d_NCHWc")
.describe(R"code(Compute conv2d with NCHWc data layout. Only supports NCHW layout.
- **data**: Input is 5D packed tensor.
- **weight**: 6D packed tensor.
- **out**: Output is 5D packed tensor
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.DepthwiseConv2D")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_support_level(10)
.add_type_rel("Conv2D", Conv2DRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
Conv2DInferCorrectLayout<Conv2DAttrs>);


} // namespace relay
} // namespace tvm
35 changes: 32 additions & 3 deletions src/relay/op/nn/upsampling.cc
@@ -18,6 +18,31 @@ namespace relay {

TVM_REGISTER_NODE_TYPE(UpSamplingAttrs);

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(
const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<Array<IndexExpr>> &old_in_shapes) {
// NOTE: Discard "const" qualifier here.
T *params = const_cast<T*>(attrs.as<T>());
Contributor: Could you explain why this is necessary? It's quite counterintuitive why this kind of thing is needed, since the desired behaviour is essentially "if input layout is NCHWc, then expect input in NCHWc, perform compute, and output in NCHWc", right?

Contributor (Author): This function is the same as the one in pooling.cc: https://github.com/dmlc/tvm/blob/master/src/relay/op/nn/pooling.cc#L22

In order to invoke the AlterLayout pass, we need the FInferCorrectLayout attribute, which is this function; it is part of the logic of the current AlterLayout pass.


if (new_in_layouts.defined()) {
CHECK_EQ(new_in_layouts.size(), 1);

Layout raw_layout(params->layout);
Layout input = new_in_layouts[0];
if (input.IndexOf(LayoutAxis::Get('W')) == raw_layout.IndexOf(LayoutAxis::Get('W')) &&
input.IndexOf(LayoutAxis::Get('H')) == raw_layout.IndexOf(LayoutAxis::Get('H')) &&
!input.Contains(LayoutAxis::Get('w')) && !input.Contains(LayoutAxis::Get('h'))) {
params->layout = input.name(); // modify self to follow the input layout
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
}
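In other words, upsampling adopts a producer's layout only when the spatial axes are left untouched: a channel-blocked layout such as NCHW16c is followed, while anything that splits H or W falls back to the op's existing layout. A simplified plain-Python model of that condition (a hypothetical helper; TVM's Layout compares axis positions rather than raw string indices):

```python
def adopts_new_layout(raw_layout, new_layout):
    # Adopt the new layout only if the primal axes H and W keep their
    # positions and are not split into sub-axes ('h' / 'w').
    return (new_layout.index('H') == raw_layout.index('H')
            and new_layout.index('W') == raw_layout.index('W')
            and 'h' not in new_layout
            and 'w' not in new_layout)

assert adopts_new_layout("NCHW", "NCHW16c")        # channel split only: adopt
assert not adopts_new_layout("NCHW", "NCHW4h4w")   # spatial split: keep old
```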

bool UpSamplingRel(const Array<Type>& types,
int num_inputs,
const Attrs& attrs,
@@ -91,6 +116,8 @@ RELAY_REGISTER_OP("nn.upsampling")
.add_argument("data", "Tensor", "The input tensor.")
.set_support_level(2)
.add_type_rel("UpSampling", UpSamplingRel)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout",
UpsamplingInferCorrectLayout<UpSamplingAttrs>)
.set_attr<TOpPattern>("TOpPattern", kInjective)
.set_attr<FTVMCompute>(
"FTVMCompute", [](const Attrs& attrs,
@@ -101,14 +128,16 @@ RELAY_REGISTER_OP("nn.upsampling")
CHECK(uattrs != nullptr);
auto out_tt = out_type.as<TensorTypeNode>();
CHECK(out_tt) << "expected a tensor type: " << out_type;
CHECK(uattrs->layout == "NCHW" || uattrs->layout == "NHWC")
const auto layout = uattrs->layout;
const auto base_layout = layout.substr(0, 4);
CHECK(base_layout == "NCHW" || layout == "NHWC")
<< "unknown layout: " << uattrs->layout;

Array<HalideIR::Expr> oshape;
if (uattrs->layout == "NCHW") {
if (base_layout == "NCHW") {
oshape.push_back(out_tt->shape[2]);
oshape.push_back(out_tt->shape[3]);
} else if (uattrs->layout == "NHWC") {
} else if (layout == "NHWC") {
oshape.push_back(out_tt->shape[1]);
oshape.push_back(out_tt->shape[2]);
}
2 changes: 1 addition & 1 deletion tests/lint/pylintrc
@@ -114,7 +114,7 @@ single-line-if-stmt=no
no-space-check=trailing-comma,dict-separator

# Maximum number of lines in a module
-max-module-lines=1000
+max-module-lines=1500

# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1
# tab).
Expand Down
46 changes: 46 additions & 0 deletions tests/python/relay/test_pass_alter_op_layout.py
@@ -411,6 +411,51 @@ def expected():

assert(alpha_equal(a, b))


def test_alter_layout_nchw_upsamping_op():
"""Test upsamping operators """
def before():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var('weight', shape=(32, 32, 3, 3))
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1))
y = relay.nn.upsampling(y, scale=2)
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2))
y = relay.Function(free_vars(y), y)
return y

@register_alter_op_layout("nn.conv2d", level=108)
def alter_conv2d(attrs, inputs, tinfos):
data, weight = inputs
new_attrs = dict(attrs)
new_attrs['data_layout'] = 'NCHW16c'
return relay.nn.conv2d(data, weight, **new_attrs)

def expected():
x = relay.var("x", shape=(1, 32, 28, 28))
weight = relay.var("weight")
x = relay.layout_transform(x, "NCHW", "NCHW16c")
y = relay.nn.conv2d(x, weight, channels=32, kernel_size=(3, 3), padding=(1, 1),
data_layout="NCHW16c")
y = relay.nn.upsampling(y, scale=2, layout="NCHW16c")
y = relay.nn.avg_pool2d(y, pool_size=(2, 2), strides=(2, 2), layout='NCHW16c')
y = relay.layout_transform(y, "NCHW16c", "NCHW")
y = relay.Function(free_vars(y), y)
return y

a = before()
a = infer_type(a)
a = canonicalize_ops(a)
a = infer_type(a)

a = alter_op_layout(a)
a = infer_type(a)

b = expected()
b = infer_type(b)

assert(alpha_equal(a, b))


if __name__ == "__main__":
test_alter_op()
test_alter_return_none()
@@ -420,3 +465,4 @@ def expected():
test_alter_layout_broadcast_op()
test_alter_layout_scalar()
test_alter_layout_concatenate()
test_alter_layout_nchw_upsamping_op()
3 changes: 3 additions & 0 deletions tests/scripts/task_python_topi.sh
@@ -4,5 +4,8 @@ export PYTHONPATH=python:topi/python
make cython || exit -1
make cython3 || exit -1

rm -rf python/tvm/*.pyc python/tvm/*/*.pyc python/tvm/*/*/*.pyc
rm -rf topi/python/topi/*.pyc topi/python/topi/*/*.pyc topi/python/topi/*/*/*.pyc topi/python/topi/*/*/*/*.pyc

python -m nose -v topi/tests/python || exit -1
python3 -m nose -v topi/tests/python || exit -1
49 changes: 47 additions & 2 deletions topi/include/topi/image/resize.h
@@ -134,6 +134,45 @@ inline Tensor resize_nearest_neighbor_nchw(const Tensor& input,
}, name, tag);
}

/*!
* \brief Resize given tensor to given shape using nearest neighbour for NCHWc
*
* \param input The input tensor.
* \param shape Output shape to resize to.
* \param align_corners To preserve centers of 4 corner pixels
* \param name Name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor resized to given shape
*/
inline Tensor resize_nearest_neighbor_nchwc(const Tensor& input,
const Array<Expr>& shape,
bool align_corners = false,
std::string name = "tensor",
std::string tag = kInjective) {
Array<Expr> out_shape;
out_shape.push_back(input->shape[0]);
out_shape.push_back(input->shape[1]);
out_shape.push_back(shape[0]);
out_shape.push_back(shape[1]);
out_shape.push_back(input->shape[4]);

Expr h_ratio = shape[0] / input->shape[2];
Expr w_ratio = shape[1] / input->shape[3];

return compute(
out_shape, [&](const Array<Var>& indices) {
Array<Expr> idx;
idx.push_back(indices[0]);
idx.push_back(indices[1]);
idx.push_back(indices[2] / h_ratio);
idx.push_back(indices[3] / w_ratio);
idx.push_back(indices[4]);

return input(idx);
}, name, tag);
}
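The index arithmetic maps each output pixel (h, w) back to input pixel (h / h_ratio, w / w_ratio), while the batch, channel-chunk, and channel-block axes pass straight through. A NumPy reference of the same computation (illustrative only, not part of the PR):

```python
import numpy as np

def upsample_nchwc_nearest(x, scale):
    """Nearest-neighbor upsampling of a packed NCHWc tensor (integer scale)."""
    n, c_chunk, h, w, c_block = x.shape
    out = np.empty((n, c_chunk, h * scale, w * scale, c_block), dtype=x.dtype)
    for i in range(h * scale):
        for j in range(w * scale):
            # same mapping as idx[2] = indices[2] / h_ratio above
            out[:, :, i, j, :] = x[:, :, i // scale, j // scale, :]
    return out

x = np.arange(1 * 2 * 2 * 2 * 4, dtype="float32").reshape(1, 2, 2, 2, 4)
assert upsample_nchwc_nearest(x, 2).shape == (1, 2, 4, 4, 4)
```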

/*!
* \brief Resize given tensor to given shape using nearest neighbour
*
@@ -153,11 +192,17 @@ inline Tensor resize_nearest_neighbor(const Tensor& input,
std::string name = "tensor",
std::string tag = kInjective) {
CHECK_EQ(align_corners, false) << "Align corners not supported for nearest neighbour";

auto base_layout = layout.substr(0, 4);
if (layout == "NHWC") {
return resize_nearest_neighbor_nhwc(input, shape, align_corners);
-} else {
+} else if (layout == "NCHW") {
return resize_nearest_neighbor_nchw(input, shape, align_corners);
} else if (base_layout == "NCHW") {
// NCHWc
return resize_nearest_neighbor_nchwc(input, shape, align_corners);
} else {
LOG(FATAL) << "Unknown layout: " << layout;
return Tensor();
}
}

4 changes: 2 additions & 2 deletions topi/python/topi/nn/depthwise_conv2d.py
@@ -292,7 +292,7 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
5-D with shape [batch, in_channel_chunk, in_height, in_width, in_channel_block]
Filter : tvm.Tensor
-4-D with shape [out_channel_chunk, filter_height, filter_width, out_channel_block]
+6-D with shape [out_channel_chunk, 1, filter_height, filter_width, 1, out_channel_block]
In NCHWc depthwise convolution,
we group the kernel's in_channel and channel_multiplier together, then do the tiling.
@@ -317,6 +317,6 @@ def depthwise_conv2d_NCHWc(Input, Filter, stride, padding, dilation,
Returns
-------
Output : tvm.Tensor
-4-D with shape [batch, out_channel, out_height, out_width]
+5-D with shape [batch, out_channel_chunk, out_height, out_width, out_channel_block]
"""
raise ValueError("missing register for topi.nn.depthwise_conv2d_NCHWc")
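To make the 6-D filter layout concrete, here is a hypothetical NumPy packing of a plain depthwise kernel (channel_multiplier = 1, block size 8); the reshape/transpose chain is an assumption for illustration, not code from this PR:

```python
import numpy as np

# depthwise kernel in plain layout: [in_channel * multiplier, 1, KH, KW]
kernel = np.random.randn(32, 1, 3, 3).astype("float32")
block = 8

# group in_channel and channel_multiplier, then tile by the channel block
packed = (kernel.reshape(32 // block, block, 1, 3, 3)   # split channel dim
                .transpose(0, 2, 3, 4, 1)               # move block innermost
                .reshape(32 // block, 1, 3, 3, 1, block))
assert packed.shape == (4, 1, 3, 3, 1, 8)  # [chunk, 1, KH, KW, 1, block]
```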
4 changes: 2 additions & 2 deletions topi/python/topi/nn/upsampling.py
@@ -30,8 +30,8 @@ def upsampling(data, scale, layout="NCHW", method='NEAREST_NEIGHBOR'):
4-D with shape [batch, channel, in_height*scale, in_width*scale]
or [batch, in_height*scale, in_width*scale, channel]
"""

if layout == "NCHW":
base_layout = layout[0:4]
if base_layout == "NCHW":
out_shape = (simplify(data.shape[2] * scale), simplify(data.shape[3] * scale))
elif layout == "NHWC":
out_shape = (simplify(data.shape[1] * scale), simplify(data.shape[2] * scale))